Commit 241baec1 authored by Adam Osewski's avatar Adam Osewski
Browse files

Refactor universal gemm policy.

parent 77a38e02
......@@ -25,7 +25,7 @@ struct GemmKernel
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
// Below type is actually accumulation data type - the output of block GEMM.
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
__host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
......@@ -238,7 +238,7 @@ struct GemmKernel
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
// Run GEMM cooperatively by whole wokrgroup.
// Run GEMM cooperatively by whole workgroup.
auto c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
......
......@@ -4,7 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
......@@ -37,7 +37,7 @@ struct BaseGemmPipelineAgBgCrCompV3
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
......@@ -62,15 +62,14 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t VectorSizeA = Problem::VectorSizeA;
static constexpr index_t VectorSizeB = Problem::VectorSizeB;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t VectorSizeA = Policy::template GetVectorSizeA<Problem>();
static constexpr index_t VectorSizeB = Policy::template GetVectorSizeB<Problem>();
static constexpr index_t VectorSizeC = Policy::template GetVectorSizeC<Problem>();
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
// Where is the right place for HasHotLoop and TailNum ???
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
......
......@@ -3,7 +3,7 @@
#pragma once
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/core.hpp"
namespace ck_tile {
......@@ -11,10 +11,10 @@ template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename TileGemmTraits_>
typename Traits_>
struct GemmPipelineProblemBase
{
using GemmTraits = remove_cvref_t<TileGemmTraits_>;
using Traits = remove_cvref_t<Traits_>;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
......@@ -22,19 +22,19 @@ struct GemmPipelineProblemBase
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using ALayout = remove_cvref_t<typename GemmTraits::ALayout>;
using BLayout = remove_cvref_t<typename GemmTraits::BLayout>;
using CLayout = remove_cvref_t<typename GemmTraits::CLayout>;
using ALayout = remove_cvref_t<typename Traits::ALayout>;
using BLayout = remove_cvref_t<typename Traits::BLayout>;
using CLayout = remove_cvref_t<typename Traits::CLayout>;
static constexpr index_t VectorLoadSize = GemmTraits::_VectorSize;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadM = GemmTraits::kPadM;
static constexpr bool kPadN = GemmTraits::kPadN;
static constexpr bool kPadK = GemmTraits::kPadK;
static constexpr bool kPadM = Traits::kPadM;
static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK;
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
{
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
......@@ -128,27 +128,43 @@ template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename TileGemmTraits_>
typename Traits_>
using GemmPipelineProblem =
GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, TileGemmTraits_>;
GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, Traits_>;
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename TileGemmTraits_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
struct UniversalGemmPipelineProblem : public GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
TileGemmTraits_>
struct UniversalGemmPipelineProblem
{
using Traits = remove_cvref_t<Traits_>;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using ALayout = remove_cvref_t<typename Traits::ALayout>;
using BLayout = remove_cvref_t<typename Traits::BLayout>;
using CLayout = remove_cvref_t<typename Traits::CLayout>;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadM = Traits::kPadM;
static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK;
static constexpr auto Scheduler = Scheduler_;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
static constexpr bool TransposeC = Traits::TransposeC;
};
} // namespace ck_tile
......@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile {
......@@ -16,8 +17,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr bool TransposeC = true;
template <typename Problem, typename DataType, index_t MNPerBlock>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorLoadSize()
{
......@@ -25,6 +24,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
// Assume DataType is even!
if constexpr(elements_per_thread % (16 / sizeof(DataType)) == 0)
{
return (16 / sizeof(DataType));
......@@ -49,6 +50,95 @@ struct UniversalGemmPipelineAgBgCrPolicy
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
return GetVectorLoadSize<Problem, ADataType, MPerBlock>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
return GetVectorLoadSize<Problem, BDataType, NPerBlock>();
}
/**
* @brief Get the vector store size for C tensor.
*
* @tparam Problem - Gemm pipeline problem class.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
using WG = typename BlockGemm::WarpGemm;
constexpr bool TransposeC = Problem::TransposeC;
using CLayout = typename Problem::CLayout;
using CWarpDstr = typename WG::CWarpDstr;
// constexpr auto c_warp_x_lengths = CWarpDstr::get_lengths();
// using c_warp_hs_lengths = typename CWarpDstrEncoding::HsLengthss;
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(TransposeC)
{
// In this case each thread has multiple consecutive elements in
// N dimension, however consecutive threads' elements have stride.
// static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
// c_warp_y_lengths.get(number<NDimY-1>{}));
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
else
{
// In this case each thread has just a single item in Ndim
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
}
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(TransposeC)
{
// In this case each thread has just a single item in Mdim
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
}
else
{
// In this case each thread has multiple consecutive elements in
// M dimension, however consecutive threads' elements have stride.
// static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
// c_warp_y_lengths.get(number<NDimY-1>{}));
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
}
else
{
static_assert(false, "Unsupported CLayout!");
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
......@@ -180,29 +270,28 @@ struct UniversalGemmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M1 = GetVectorSizeA<Problem>();
constexpr index_t M0 = MPerBlock / M1;
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % M1 == 0);
constexpr index_t K3 = total_pixels / M1;
constexpr index_t KPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
constexpr index_t elem_per_thr = MPerBlock * KPerBlock / BlockSize;
constexpr index_t K3 = elem_per_thr / M1; // # of loads per thr
constexpr index_t KPack = GetVectorSizeA<Problem>();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() % (K2 * M0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * M0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
......@@ -217,6 +306,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
......@@ -228,15 +318,21 @@ struct UniversalGemmPipelineAgBgCrPolicy
}
else
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
// In RowMajor scenario we usually want to read whole KPerBlock tile dim
constexpr index_t K1 = GetVectorSizeA<Problem>();
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
// Coalesce reading for whole workgroup - workgroup raked pattern
if constexpr(get_warp_size() % (M2 * K0) == 0)
{
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
......@@ -245,10 +341,15 @@ struct UniversalGemmPipelineAgBgCrPolicy
sequence<1, 2>,
sequence<0, 1>>{});
}
// Coalesce reading for each wavefront - wavefront raked pattern
else
{
constexpr index_t M0 = BlockSize / get_warp_size();
constexpr index_t M1 = MPerBlock / (M2 * M0);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
......@@ -263,22 +364,20 @@ struct UniversalGemmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t N1 = GetVectorSizeB<Problem>();
constexpr index_t N0 = NPerBlock / N1;
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % N1 == 0);
constexpr index_t K3 = total_pixels / N1;
constexpr index_t KPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
constexpr index_t elem_per_thr = NPerBlock * KPerBlock / BlockSize;
static_assert(elem_per_thr % N1 == 0);
constexpr index_t K3 = elem_per_thr / N1;
constexpr index_t KPack = GetVectorSizeB<Problem>();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() % (K2 * N0) == 0)
......@@ -286,6 +385,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
......@@ -300,6 +400,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
......@@ -312,7 +413,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
else
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t K1 = GetVectorSizeB<Problem>();
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
// coalesce reading for each blocks
......@@ -322,6 +423,9 @@ struct UniversalGemmPipelineAgBgCrPolicy
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1);
static_assert(N0 * N1 * N2 == NPerBlock,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
......@@ -336,6 +440,9 @@ struct UniversalGemmPipelineAgBgCrPolicy
{
constexpr index_t N0 = BlockSize / get_warp_size();
constexpr index_t N1 = NPerBlock / (N2 * N0);
static_assert(N0 * N1 * N2 == NPerBlock,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
......@@ -447,22 +554,21 @@ struct UniversalGemmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
AccDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
TransposeC>;
Problem::TransposeC>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{};
return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
}
};
......
......@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile {
......@@ -19,11 +20,34 @@ struct TileGemmTraits
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
// TODO this can't be hardcoded here! Should be in policy!
static constexpr int _VectorSize = 16;
using ALayout = ALayout_;
using BLayout = BLayout_;
using CLayout = CLayout_;
static constexpr bool TransposeC = false;
};
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
bool TransposeC_ = false>
struct TileGemmUniversalTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
using ALayout = ALayout_;
using BLayout = BLayout_;
using CLayout = CLayout_;
static constexpr bool TransposeC = TransposeC_;
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment