Commit 57e6fd46 authored by Adam Osewski's avatar Adam Osewski
Browse files

Adding shuffled encoding patterns.

parent c400e5b3
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
namespace ck_tile { namespace ck_tile {
/** /**
* @brief Enumeration describing tile distribution patterns. * @brief Enumeration describing static tile distribution patterns.
* *
*/ */
enum struct tile_distribution_pattern enum struct tile_distribution_pattern
...@@ -34,8 +34,6 @@ enum struct tile_distribution_pattern ...@@ -34,8 +34,6 @@ enum struct tile_distribution_pattern
* *
*/ */
block_raked, block_raked,
// TODO pattern taking into account MFMA attributes:
// block_fmha_pipeline_qx_ks_vs_custom_policy.hpp::51 MakeQDramTileDistribution()
}; };
struct TileDistributionEcodingPattern struct TileDistributionEcodingPattern
...@@ -73,27 +71,27 @@ struct TileDistributionEncodingPattern2D<BlockSize, ...@@ -73,27 +71,27 @@ struct TileDistributionEncodingPattern2D<BlockSize,
: public TileDistributionEcodingPattern : public TileDistributionEcodingPattern
{ {
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution() // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
{ static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! static constexpr index_t warp_size = get_warp_size();
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); static constexpr index_t num_warps = BlockSize / get_warp_size();
constexpr index_t warp_size = get_warp_size(); static constexpr index_t X1 = VecSize;
constexpr index_t num_warps = BlockSize / get_warp_size(); static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
constexpr index_t X1 = VecSize;
constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
// # of rows in Y dim accessed by single wavefront in one iteration // # of rows in Y dim accessed by single wavefront in one iteration
constexpr index_t Y1 = warp_size / X0; static constexpr index_t Y1 = warp_size / X0;
static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!"); static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
constexpr index_t Y0 = num_warps; static constexpr index_t Y0 = num_warps;
// YPerWarp = YPerTile / Y0; // YPerWarp = YPerTile / Y0;
// Y2 = YPerWarp / Y1; // Y2 = YPerWarp / Y1;
constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront
static_assert(X0 * Y1 * Y0 == BlockSize, "X0 * warp_ys * Y0 must cover whole workgroup!"); static_assert(X0 * Y1 * Y0 == BlockSize, "X0 * warp_ys * Y0 must cover whole workgroup!");
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile"); static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>, tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>, tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
...@@ -102,6 +100,17 @@ struct TileDistributionEncodingPattern2D<BlockSize, ...@@ -102,6 +100,17 @@ struct TileDistributionEncodingPattern2D<BlockSize,
sequence<1, 2>, sequence<1, 2>,
sequence<2, 1>>{}); sequence<2, 1>>{});
} }
/// @brief Build the shuffled variant of the 2D static tile distribution.
///
/// Compared to Make2DStaticTileDistribution() in this specialization, the
/// hierarchical dims are swapped: the X split <X0, X1> comes first and the
/// Y split <Y0, Y1, Y2> second, with the P/Y mappings adjusted accordingly.
/// X0/X1/Y0/Y1/Y2 are the class-scope constants derived from BlockSize,
/// YPerTile, XPerTile and VecSize.
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
{
    // Name the encoding so the template arguments read as one unit.
    using shuffled_encoding =
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
                                   tuple<sequence<2>, sequence<2, 1>>,
                                   tuple<sequence<0>, sequence<1, 0>>,
                                   sequence<1, 2>,
                                   sequence<1, 2>>;
    return make_static_tile_distribution(shuffled_encoding{});
}
}; };
// Warp raked // Warp raked
...@@ -113,23 +122,24 @@ struct TileDistributionEncodingPattern2D<BlockSize, ...@@ -113,23 +122,24 @@ struct TileDistributionEncodingPattern2D<BlockSize,
tile_distribution_pattern::warp_raked> tile_distribution_pattern::warp_raked>
: public TileDistributionEcodingPattern : public TileDistributionEcodingPattern
{ {
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
constexpr index_t warp_size = get_warp_size();
constexpr index_t num_warps = BlockSize / get_warp_size();
constexpr index_t X1 = VecSize;
constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!"); static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
static constexpr index_t X1 = VecSize;
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
constexpr index_t Y0 = num_warps; static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
static_assert(X0 * Y2 * Y0 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!"); static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront static constexpr index_t Y0 = num_warps;
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile"); static_assert(X0 * Y2 * Y0 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
static constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>, tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>, tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
...@@ -138,6 +148,17 @@ struct TileDistributionEncodingPattern2D<BlockSize, ...@@ -138,6 +148,17 @@ struct TileDistributionEncodingPattern2D<BlockSize,
sequence<1, 2>, sequence<1, 2>,
sequence<1, 1>>{}); sequence<1, 1>>{});
} }
/// @brief Build the shuffled variant of the warp-raked 2D static tile distribution.
///
/// Same class-scope constants (X0, X1, Y0, Y1, Y2) as Make2DStaticTileDistribution()
/// in this specialization, but with the hierarchical dims swapped: the X split
/// <X0, X1> is listed first and the Y split <Y0, Y1, Y2> second, with the
/// P-dim/Y-dim index sequences updated to match the new ordering.
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
}; };
// Block raked // Block raked
...@@ -150,21 +171,21 @@ struct TileDistributionEncodingPattern2D<BlockSize, ...@@ -150,21 +171,21 @@ struct TileDistributionEncodingPattern2D<BlockSize,
: public TileDistributionEcodingPattern : public TileDistributionEcodingPattern
{ {
// Derived layout constants for the block-raked pattern, computed once at
// class scope so both Make2DStaticTileDistribution() and
// MakeShuffled2DStaticTileDistribution() can share them.
//   X1 = vector access width, X0 = lanes along X,
//   Y2 = rows one wavefront covers, Y1 = number of warps,
//   Y0 = per-thread iterations along Y.
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
static constexpr index_t X1 = VecSize;
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
static constexpr index_t Y1 = num_warps;
static_assert(X0 * Y2 * Y1 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
static constexpr index_t Y0 = YPerTile / (Y2 * Y1); // # of iters
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{ {
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
constexpr index_t warp_size = get_warp_size();
constexpr index_t num_warps = BlockSize / get_warp_size();
constexpr index_t X1 = VecSize;
constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
constexpr index_t Y1 = num_warps;
static_assert(X0 * Y2 * Y1 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
constexpr index_t Y0 = YPerTile / (Y2 * Y1);
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>, tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>, tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
...@@ -173,6 +194,17 @@ struct TileDistributionEncodingPattern2D<BlockSize, ...@@ -173,6 +194,17 @@ struct TileDistributionEncodingPattern2D<BlockSize,
sequence<1, 2>, sequence<1, 2>,
sequence<0, 1>>{}); sequence<0, 1>>{});
} }
/// @brief Build the shuffled variant of the block-raked 2D static tile distribution.
///
/// Reuses the class-scope constants (X0, X1, Y0, Y1, Y2) but lists the X split
/// <X0, X1> before the Y split <Y0, Y1, Y2>, with the P-dim and Y-dim index
/// sequences updated for the swapped ordering.
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
{
    // Bind the encoding to a name so each template argument row stays readable.
    using shuffled_encoding =
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
                                   tuple<sequence<2>, sequence<2, 1>>,
                                   tuple<sequence<1>, sequence<2, 0>>,
                                   sequence<1, 2>,
                                   sequence<1, 0>>;
    return make_static_tile_distribution(shuffled_encoding{});
}
}; };
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment