Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
57e6fd46
"...composable_kernel_rocm.git" did not exist on "f95267f166927bee1d806cefbdc142b2e35f640f"
Commit
57e6fd46
authored
Jan 21, 2025
by
Adam Osewski
Browse files
Adding shuffled encoding patterns.
parent
c400e5b3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
78 additions
and
46 deletions
+78
-46
include/ck_tile/core/algorithm/static_encoding_pattern.hpp
include/ck_tile/core/algorithm/static_encoding_pattern.hpp
+78
-46
No files found.
include/ck_tile/core/algorithm/static_encoding_pattern.hpp
View file @
57e6fd46
...
...
@@ -14,7 +14,7 @@
namespace
ck_tile
{
/**
* @brief Enumeration describing tile distribution patterns.
* @brief Enumeration describing static tile distribution patterns.
*
*/
enum
struct
tile_distribution_pattern
...
...
@@ -34,8 +34,6 @@ enum struct tile_distribution_pattern
*
*/
block_raked
,
// TODO pattern taking into account MFMA attributes:
// block_fmha_pipeline_qx_ks_vs_custom_policy.hpp::51 MakeQDramTileDistribution()
};
struct
TileDistributionEcodingPattern
...
...
@@ -73,27 +71,27 @@ struct TileDistributionEncodingPattern2D<BlockSize,
:
public
TileDistributionEcodingPattern
{
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make2DStaticTileDistribution
()
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert
(
XPerTile
%
VecSize
==
0
,
"XPerTile must be a multiple of VecSize!"
);
constexpr
index_t
warp_size
=
get_warp_size
();
constexpr
index_t
num_warps
=
BlockSize
/
get_warp_size
();
constexpr
index_t
X1
=
VecSize
;
constexpr
index_t
X0
=
XPerTile
/
X1
;
// # of threads in X dim
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert
(
XPerTile
%
VecSize
==
0
,
"XPerTile must be a multiple of VecSize!"
);
static
constexpr
index_t
warp_size
=
get_warp_size
();
static
constexpr
index_t
num_warps
=
BlockSize
/
get_warp_size
();
static
constexpr
index_t
X1
=
VecSize
;
static
constexpr
index_t
X0
=
XPerTile
/
X1
;
// # of threads in X dim
// # of rows in Y dim accessed by single wavefront in one iteration
constexpr
index_t
Y1
=
warp_size
/
X0
;
static_assert
(
X0
*
Y1
==
warp_size
,
"X0 * Y1 must cover whole wavefront!"
);
// # of rows in Y dim accessed by single wavefront in one iteration
static
constexpr
index_t
Y1
=
warp_size
/
X0
;
static_assert
(
X0
*
Y1
==
warp_size
,
"X0 * Y1 must cover whole wavefront!"
);
constexpr
index_t
Y0
=
num_warps
;
// YPerWarp = YPerTile / Y0;
// Y2 = YPerWarp / Y1;
constexpr
index_t
Y2
=
YPerTile
/
(
Y1
*
Y0
);
// # of iters within wavefront
static
constexpr
index_t
Y0
=
num_warps
;
// YPerWarp = YPerTile / Y0;
// Y2 = YPerWarp / Y1;
static
constexpr
index_t
Y2
=
YPerTile
/
(
Y1
*
Y0
);
// # of iters within wavefront
static_assert
(
X0
*
Y1
*
Y0
==
BlockSize
,
"X0 * warp_ys * Y0 must cover whole workgroup!"
);
static_assert
(
Y0
*
Y1
*
Y2
==
YPerTile
,
"Y0, Y1, Y2 must cover whole YPerTile"
);
static_assert
(
X0
*
Y1
*
Y0
==
BlockSize
,
"X0 * warp_ys * Y0 must cover whole workgroup!"
);
static_assert
(
Y0
*
Y1
*
Y2
==
YPerTile
,
"Y0, Y1, Y2 must cover whole YPerTile"
);
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make2DStaticTileDistribution
()
{
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
Y0
,
Y1
,
Y2
>
,
sequence
<
X0
,
X1
>>
,
...
...
@@ -102,6 +100,17 @@ struct TileDistributionEncodingPattern2D<BlockSize,
sequence
<
1
,
2
>
,
sequence
<
2
,
1
>>
{});
}
/**
 * @brief Build the static tile distribution with the X lengths listed before the Y lengths.
 *
 * Make2DStaticTileDistribution() of this specialization passes the lengths as
 * (Y0, Y1, Y2) followed by (X0, X1); this variant passes (X0, X1) first and
 * (Y0, Y1, Y2) second, with the index sequences adjusted to match.
 *
 * @return The object produced by make_static_tile_distribution() for the
 *         encoding described above.
 */
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
{
    // Per-dimension length sequences, X first because this is the shuffled layout.
    using XLengths = sequence<X0, X1>;
    using YLengths = sequence<Y0, Y1, Y2>;

    // NOTE(review): the four trailing index sequences presumably address entries of
    // the lengths tuple (1 -> XLengths, 2 -> YLengths) - confirm against the
    // tile_distribution_encoding definition.
    using ShuffledEncoding = tile_distribution_encoding<sequence<1>,
                                                        tuple<XLengths, YLengths>,
                                                        tuple<sequence<2>, sequence<2, 1>>,
                                                        tuple<sequence<0>, sequence<1, 0>>,
                                                        sequence<1, 2>,
                                                        sequence<1, 2>>;

    return make_static_tile_distribution(ShuffledEncoding{});
}
};
// Warp raked
...
...
@@ -113,23 +122,24 @@ struct TileDistributionEncodingPattern2D<BlockSize,
tile_distribution_pattern
::
warp_raked
>
:
public
TileDistributionEcodingPattern
{
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make2DStaticTileDistribution
()
{
static_assert
(
XPerTile
%
VecSize
==
0
,
"XPerTile must be a multiple of VecSize!"
);
constexpr
index_t
warp_size
=
get_warp_size
();
constexpr
index_t
num_warps
=
BlockSize
/
get_warp_size
();
constexpr
index_t
X1
=
VecSize
;
constexpr
index_t
X0
=
XPerTile
/
X1
;
// # of threads in X dim
constexpr
index_t
Y2
=
warp_size
/
X0
;
// # of rows in Y dim to cover whole wavefront
static_assert
(
X0
*
Y2
==
warp_size
,
"X0 * Y2 must cover whole wavefront!"
);
static_assert
(
XPerTile
%
VecSize
==
0
,
"XPerTile must be a multiple of VecSize!"
);
static
constexpr
index_t
warp_size
=
get_warp_size
();
static
constexpr
index_t
num_warps
=
BlockSize
/
get_warp_size
();
static
constexpr
index_t
X1
=
VecSize
;
static
constexpr
index_t
X0
=
XPerTile
/
X1
;
// # of threads in X dim
constexpr
index_t
Y
0
=
num_warps
;
static_assert
(
X0
*
Y2
*
Y0
==
BlockS
ize
,
"X0 * Y2
* Y1
must cover whole w
orkgroup
!"
);
static
constexpr
index_t
Y
2
=
warp_size
/
X0
;
// # of rows in Y dim to cover whole wavefront
static_assert
(
X0
*
Y2
==
warp_s
ize
,
"X0 * Y2 must cover whole w
avefront
!"
);
constexpr
index_t
Y
1
=
YPerTile
/
(
Y2
*
Y0
);
// # of iters within wavefront
static_assert
(
Y
0
*
Y
1
*
Y
2
==
YPerTile
,
"Y0, Y1,
Y
2
must cover whole
YPerTile
"
);
static
constexpr
index_t
Y
0
=
num_warps
;
static_assert
(
X
0
*
Y
2
*
Y
0
==
BlockSize
,
"X0 * Y2 *
Y
1
must cover whole
workgroup!
"
);
static
constexpr
index_t
Y1
=
YPerTile
/
(
Y2
*
Y0
);
// # of iters within wavefront
static_assert
(
Y0
*
Y1
*
Y2
==
YPerTile
,
"Y0, Y1, Y2 must cover whole YPerTile"
);
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make2DStaticTileDistribution
()
{
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
Y0
,
Y1
,
Y2
>
,
sequence
<
X0
,
X1
>>
,
...
...
@@ -138,6 +148,17 @@ struct TileDistributionEncodingPattern2D<BlockSize,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
/**
 * @brief Build the warp_raked static tile distribution with the X lengths listed before
 *        the Y lengths.
 *
 * Make2DStaticTileDistribution() of this specialization passes the lengths as
 * (Y0, Y1, Y2) followed by (X0, X1); this variant passes (X0, X1) first and
 * (Y0, Y1, Y2) second, with the index sequences adjusted to match.
 *
 * @return The object produced by make_static_tile_distribution() for the
 *         encoding described above.
 */
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
{
    // Per-dimension length sequences, X first because this is the shuffled layout.
    using XLengths = sequence<X0, X1>;
    using YLengths = sequence<Y0, Y1, Y2>;

    // NOTE(review): the four trailing index sequences presumably address entries of
    // the lengths tuple (1 -> XLengths, 2 -> YLengths) - confirm against the
    // tile_distribution_encoding definition.
    using ShuffledEncoding = tile_distribution_encoding<sequence<1>,
                                                        tuple<XLengths, YLengths>,
                                                        tuple<sequence<2>, sequence<2, 1>>,
                                                        tuple<sequence<0>, sequence<2, 0>>,
                                                        sequence<1, 2>,
                                                        sequence<1, 1>>;

    return make_static_tile_distribution(ShuffledEncoding{});
}
};
// Block raked
...
...
@@ -150,21 +171,21 @@ struct TileDistributionEncodingPattern2D<BlockSize,
:
public
TileDistributionEcodingPattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert
(
XPerTile
%
VecSize
==
0
,
"XPerTile must be a multiple of VecSize!"
);
static
constexpr
index_t
warp_size
=
get_warp_size
();
static
constexpr
index_t
num_warps
=
BlockSize
/
get_warp_size
();
static
constexpr
index_t
X1
=
VecSize
;
static
constexpr
index_t
X0
=
XPerTile
/
X1
;
// # of threads in X dim
static
constexpr
index_t
Y2
=
warp_size
/
X0
;
// # of rows in Y dim to cover whole wavefront
static_assert
(
X0
*
Y2
==
warp_size
,
"X0 * Y2 must cover whole wavefront!"
);
static
constexpr
index_t
Y1
=
num_warps
;
static_assert
(
X0
*
Y2
*
Y1
==
BlockSize
,
"X0 * Y2 * Y1 must cover whole workgroup!"
);
static
constexpr
index_t
Y0
=
YPerTile
/
(
Y2
*
Y1
);
// # of iters
static_assert
(
Y0
*
Y1
*
Y2
==
YPerTile
,
"Y0, Y1, Y2 must cover whole YPerTile"
);
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make2DStaticTileDistribution
()
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert
(
XPerTile
%
VecSize
==
0
,
"XPerTile must be a multiple of VecSize!"
);
constexpr
index_t
warp_size
=
get_warp_size
();
constexpr
index_t
num_warps
=
BlockSize
/
get_warp_size
();
constexpr
index_t
X1
=
VecSize
;
constexpr
index_t
X0
=
XPerTile
/
X1
;
// # of threads in X dim
constexpr
index_t
Y2
=
warp_size
/
X0
;
// # of rows in Y dim to cover whole wavefront
static_assert
(
X0
*
Y2
==
warp_size
,
"X0 * Y2 must cover whole wavefront!"
);
constexpr
index_t
Y1
=
num_warps
;
static_assert
(
X0
*
Y2
*
Y1
==
BlockSize
,
"X0 * Y2 * Y1 must cover whole workgroup!"
);
constexpr
index_t
Y0
=
YPerTile
/
(
Y2
*
Y1
);
static_assert
(
Y0
*
Y1
*
Y2
==
YPerTile
,
"Y0, Y1, Y2 must cover whole YPerTile"
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
Y0
,
Y1
,
Y2
>
,
sequence
<
X0
,
X1
>>
,
...
...
@@ -173,6 +194,17 @@ struct TileDistributionEncodingPattern2D<BlockSize,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
/**
 * @brief Build the block-raked static tile distribution with the X lengths listed before
 *        the Y lengths.
 *
 * The lengths are passed as (X0, X1) first and (Y0, Y1, Y2) second. In this
 * specialization Y1 is the number of warps and Y0 is YPerTile / (Y2 * Y1).
 *
 * @return The object produced by make_static_tile_distribution() for the
 *         encoding described above.
 */
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
{
    // Per-dimension length sequences, X first because this is the shuffled layout.
    using XLengths = sequence<X0, X1>;
    using YLengths = sequence<Y0, Y1, Y2>;

    // NOTE(review): the four trailing index sequences presumably address entries of
    // the lengths tuple (1 -> XLengths, 2 -> YLengths) - confirm against the
    // tile_distribution_encoding definition.
    using ShuffledEncoding = tile_distribution_encoding<sequence<1>,
                                                        tuple<XLengths, YLengths>,
                                                        tuple<sequence<2>, sequence<2, 1>>,
                                                        tuple<sequence<1>, sequence<2, 0>>,
                                                        sequence<1, 2>,
                                                        sequence<1, 0>>;

    return make_static_tile_distribution(ShuffledEncoding{});
}
};
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment