Commit 241baec1 authored by Adam Osewski

Refactor universal gemm policy.

parent 77a38e02
@@ -25,7 +25,7 @@ struct GemmKernel
     using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
     using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
-    // using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
+    // Below type is actually accumulation data type - the output of block GEMM.
     using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

     __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
@@ -238,7 +238,7 @@ struct GemmKernel
        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);

-       // Run GEMM cooperatively by whole wokrgroup.
+       // Run GEMM cooperatively by whole workgroup.
        auto c_block_tile =
           GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
@@ -4,7 +4,7 @@
 #pragma once

 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
@@ -37,7 +37,7 @@ struct BaseGemmPipelineAgBgCrCompV3
 // LocalPreFillStages: 1
 // LocalPreFetchStages: 1
 // LocalSharedMemoryBuffer: 1
-template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
+template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
 struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
 {
     using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
@@ -62,15 +62,14 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
     static constexpr index_t NPerBlock = BlockGemmShape::kN;
     static constexpr index_t KPerBlock = BlockGemmShape::kK;

-    static constexpr index_t VectorSizeA = Problem::VectorSizeA;
-    static constexpr index_t VectorSizeB = Problem::VectorSizeB;
-    static constexpr index_t VectorSizeC = Problem::VectorSizeC;
+    static constexpr index_t VectorSizeA = Policy::template GetVectorSizeA<Problem>();
+    static constexpr index_t VectorSizeB = Policy::template GetVectorSizeB<Problem>();
+    static constexpr index_t VectorSizeC = Policy::template GetVectorSizeC<Problem>();

     static constexpr bool kPadM = Problem::kPadM;
     static constexpr bool kPadN = Problem::kPadN;
     static constexpr bool kPadK = Problem::kPadK;

-    // Where is the right place for HasHotLoop and TailNum ???
     static constexpr bool HasHotLoop = Problem::HasHotLoop;
     static constexpr auto TailNum = Problem::TailNum;
     static constexpr auto Scheduler = Problem::Scheduler;
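Note on the hunk above: the pipeline no longer reads `VectorSizeA/B/C` from the `Problem`; it asks its `Policy` for them, and the default policy is now the universal one. A minimal sketch of how a caller sees this (here `MyProblem` is a hypothetical problem type, not something defined in this commit):

```cpp
// Sketch only: vector sizes are derived by the default universal policy
// rather than stored on the problem. "MyProblem" is a placeholder.
using Policy   = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3<MyProblem>; // Policy defaulted

static_assert(Pipeline::VectorSizeA == Policy::GetVectorSizeA<MyProblem>());
static_assert(Pipeline::VectorSizeB == Policy::GetVectorSizeB<MyProblem>());
static_assert(Pipeline::VectorSizeC == Policy::GetVectorSizeC<MyProblem>());
```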
@@ -3,7 +3,7 @@
 #pragma once

-#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
+#include "ck_tile/core.hpp"

 namespace ck_tile {
@@ -11,10 +11,10 @@ template <typename ADataType_,
           typename BDataType_,
           typename CDataType_,
           typename BlockGemmShape_,
-          typename TileGemmTraits_>
+          typename Traits_>
 struct GemmPipelineProblemBase
 {
-    using GemmTraits = remove_cvref_t<TileGemmTraits_>;
+    using Traits = remove_cvref_t<Traits_>;

     using ADataType = remove_cvref_t<ADataType_>;
     using BDataType = remove_cvref_t<BDataType_>;
@@ -22,19 +22,19 @@ struct GemmPipelineProblemBase
     using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

-    using ALayout = remove_cvref_t<typename GemmTraits::ALayout>;
-    using BLayout = remove_cvref_t<typename GemmTraits::BLayout>;
-    using CLayout = remove_cvref_t<typename GemmTraits::CLayout>;
+    using ALayout = remove_cvref_t<typename Traits::ALayout>;
+    using BLayout = remove_cvref_t<typename Traits::BLayout>;
+    using CLayout = remove_cvref_t<typename Traits::CLayout>;

-    static constexpr index_t VectorLoadSize = GemmTraits::_VectorSize;
     static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();

-    static constexpr bool kPadM = GemmTraits::kPadM;
-    static constexpr bool kPadN = GemmTraits::kPadN;
-    static constexpr bool kPadK = GemmTraits::kPadK;
+    static constexpr bool kPadM = Traits::kPadM;
+    static constexpr bool kPadN = Traits::kPadN;
+    static constexpr bool kPadK = Traits::kPadK;

     static constexpr auto Scheduler = GemmPipelineScheduler::Default;
+    static constexpr index_t VectorLoadSize = Traits::_VectorSize;

     CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
     {
         if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
@@ -128,27 +128,43 @@ template <typename ADataType_,
           typename BDataType_,
           typename CDataType_,
           typename BlockGemmShape_,
-          typename TileGemmTraits_>
+          typename Traits_>
 using GemmPipelineProblem =
-    GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, TileGemmTraits_>;
+    GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, Traits_>;

 template <typename ADataType_,
           typename BDataType_,
           typename CDataType_,
           typename BlockGemmShape_,
-          typename TileGemmTraits_,
+          typename Traits_,
           GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
           bool HasHotLoop_ = true,
           TailNumber TailNum_ = TailNumber::Full>
-struct UniversalGemmPipelineProblem : public GemmPipelineProblemBase<ADataType_,
-                                                                     BDataType_,
-                                                                     CDataType_,
-                                                                     BlockGemmShape_,
-                                                                     TileGemmTraits_>
+struct UniversalGemmPipelineProblem
 {
+    using Traits = remove_cvref_t<Traits_>;
+
+    using ADataType = remove_cvref_t<ADataType_>;
+    using BDataType = remove_cvref_t<BDataType_>;
+    using CDataType = remove_cvref_t<CDataType_>;
+
+    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
+
+    using ALayout = remove_cvref_t<typename Traits::ALayout>;
+    using BLayout = remove_cvref_t<typename Traits::BLayout>;
+    using CLayout = remove_cvref_t<typename Traits::CLayout>;
+
+    static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
+
+    static constexpr bool kPadM = Traits::kPadM;
+    static constexpr bool kPadN = Traits::kPadN;
+    static constexpr bool kPadK = Traits::kPadK;
+
     static constexpr auto Scheduler = Scheduler_;
     static constexpr auto HasHotLoop = HasHotLoop_;
     static constexpr auto TailNum = TailNum_;
+
+    static constexpr bool TransposeC = Traits::TransposeC;
 };

 } // namespace ck_tile
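The new `UniversalGemmPipelineProblem` carries the layouts and padding flags directly instead of inheriting them, and adds scheduler, hot-loop and tail-number knobs plus `TransposeC`. A hedged usage sketch, with `MyBlockGemmShape` and `MyTraits` standing in for types defined elsewhere in the library:

```cpp
// Sketch only: bake the loop structure into the problem type at compile time.
using Problem = ck_tile::UniversalGemmPipelineProblem<
    ck_tile::half_t,                           // ADataType
    ck_tile::half_t,                           // BDataType
    float,                                     // CDataType (block GEMM accumulator)
    MyBlockGemmShape,                          // placeholder block shape
    MyTraits,                                  // e.g. a TileGemmUniversalTraits
    ck_tile::GemmPipelineScheduler::Intrawave, // Scheduler_
    /*HasHotLoop_=*/true,
    ck_tile::TailNumber::Full>;                // TailNum_

static_assert(Problem::HasHotLoop);
static_assert(Problem::TransposeC == MyTraits::TransposeC); // new member above
```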
@@ -5,6 +5,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
+#include "ck_tile/ops/common/tensor_layout.hpp"

 namespace ck_tile {
@@ -16,8 +17,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
     static constexpr auto I1 = number<1>{};
     static constexpr auto I2 = number<2>{};

-    static constexpr bool TransposeC = true;
-
     template <typename Problem, typename DataType, index_t MNPerBlock>
     CK_TILE_HOST_DEVICE static constexpr auto GetVectorLoadSize()
     {
@@ -25,6 +24,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

         constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
+
+        // Assume DataType is even!
         if constexpr(elements_per_thread % (16 / sizeof(DataType)) == 0)
         {
             return (16 / sizeof(DataType));
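To make the divisibility rule above concrete, here is a worked example with illustrative numbers (none of them come from the commit):

```cpp
// 16 bytes is the widest per-thread load considered here, so a 2-byte type
// (e.g. fp16) gives a full vector of 16 / 2 = 8 elements.
constexpr int MNPerBlock = 128;
constexpr int KPerBlock  = 32;
constexpr int BlockSize  = 256;
constexpr int elements_per_thread = MNPerBlock * KPerBlock / BlockSize; // 16
static_assert(elements_per_thread % 8 == 0); // 16 % (16 / sizeof(fp16)) == 0
// => GetVectorLoadSize() returns 8 elements for this configuration.
```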
@@ -49,6 +50,95 @@ struct UniversalGemmPipelineAgBgCrPolicy
         }
     }

+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
+    {
+        using ADataType = remove_cvref_t<typename Problem::ADataType>;
+        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
+        return GetVectorLoadSize<Problem, ADataType, MPerBlock>();
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
+    {
+        using BDataType = remove_cvref_t<typename Problem::BDataType>;
+        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
+        return GetVectorLoadSize<Problem, BDataType, NPerBlock>();
+    }
+
+    /**
+     * @brief Get the vector store size for C tensor.
+     *
+     * @tparam Problem - Gemm pipeline problem class.
+     *
+     * @note The vector store size for the output C tensor depends on multiple factors,
+     *       like its data layout and warp gemm C transposition. In general it is
+     *       the number of consecutive elements in the contiguous C dimension held by
+     *       a single thread.
+     *
+     * @return The vector store size for C tensor.
+     */
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
+    {
+        using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
+        using WG        = typename BlockGemm::WarpGemm;
+
+        constexpr bool TransposeC = Problem::TransposeC;
+
+        using CLayout   = typename Problem::CLayout;
+        using CWarpDstr = typename WG::CWarpDstr;
+        // constexpr auto c_warp_x_lengths = CWarpDstr::get_lengths();
+        // using c_warp_hs_lengths = typename CWarpDstrEncoding::HsLengthss;
+
+        // N is contiguous dimension
+        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
+        {
+            if constexpr(TransposeC)
+            {
+                // In this case each thread has multiple consecutive elements in
+                // N dimension, however consecutive threads' elements have stride.
+                // static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
+                //               c_warp_y_lengths.get(number<NDimY-1>{}));
+                constexpr index_t NDimY = CWarpDstr::NDimY;
+                constexpr auto c_warp_y_lengths =
+                    CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
+                return c_warp_y_lengths.get(number<NDimY - 1>{});
+            }
+            else
+            {
+                // In this case each thread has just a single item in N dim.
+                return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
+            }
+        }
+        // M is contiguous dimension
+        else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
+        {
+            if constexpr(TransposeC)
+            {
+                // In this case each thread has just a single item in M dim.
+                return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
+            }
+            else
+            {
+                // In this case each thread has multiple consecutive elements in
+                // M dimension, however consecutive threads' elements have stride.
+                // static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
+                //               c_warp_y_lengths.get(number<NDimY-1>{}));
+                constexpr index_t NDimY = CWarpDstr::NDimY;
+                constexpr auto c_warp_y_lengths =
+                    CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
+                return c_warp_y_lengths.get(number<NDimY - 1>{});
+            }
+        }
+        else
+        {
+            static_assert(false, "Unsupported CLayout!");
+        }
+    }
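A small usage sketch for the helper documented above (again `MyProblem` is a placeholder; an epilogue would typically use this value as its store vector width):

```cpp
// Sketch only: query the store vector size the policy derives for C.
using Policy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;

constexpr ck_tile::index_t vec_c = Policy::GetVectorSizeC<MyProblem>();
// vec_c is the number of consecutive C elements a single thread owns in the
// contiguous dimension, so it bounds the widest safe vector store.
```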
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
     {
@@ -180,29 +270,28 @@ struct UniversalGemmPipelineAgBgCrPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
     {
-        using ADataType = remove_cvref_t<typename Problem::ADataType>;
-        using ALayout   = remove_cvref_t<typename Problem::ALayout>;
+        using ALayout = remove_cvref_t<typename Problem::ALayout>;

         constexpr index_t BlockSize = Problem::kBlockSize;

         constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

         if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
         {
-            constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
+            constexpr index_t M1 = GetVectorSizeA<Problem>();
             constexpr index_t M0 = MPerBlock / M1;
-            constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
-            static_assert(total_pixels % M1 == 0);
-            constexpr index_t K3 = total_pixels / M1;
-            constexpr index_t KPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
+            constexpr index_t elem_per_thr = MPerBlock * KPerBlock / BlockSize;
+            constexpr index_t K3 = elem_per_thr / M1; // # of loads per thr
+            constexpr index_t KPack = GetVectorSizeA<Problem>();
             static_assert(KPack % K3 == 0);
             constexpr index_t K2 = KPack / K3;
             if constexpr(get_warp_size() % (K2 * M0) == 0)
             {
                 constexpr index_t K1 = get_warp_size() / (K2 * M0);
                 constexpr index_t K0 = BlockSize / get_warp_size();
                 static_assert(KPerBlock == K0 * K1 * K2 * K3);
+
                 return make_static_tile_distribution(
                     tile_distribution_encoding<sequence<1>,
                                                tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
@@ -217,6 +306,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
                 constexpr index_t K2_m = K2 / K1;
                 constexpr index_t K0 = BlockSize / get_warp_size() / K1;
                 static_assert(KPerBlock == K0 * K1 * K2_m * K3);
+
                 return make_static_tile_distribution(
                     tile_distribution_encoding<sequence<1>,
                                                tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
@@ -228,15 +318,21 @@ struct UniversalGemmPipelineAgBgCrPolicy
         }
         else
         {
-            constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
+            // In RowMajor scenario we usually want to read the whole KPerBlock tile dim.
+            constexpr index_t K1 = GetVectorSizeA<Problem>();
             constexpr index_t K0 = KPerBlock / K1;
             constexpr index_t M2 = get_warp_size() / K0;
+            static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
+
+            // Coalesce reading for whole workgroup - workgroup raked pattern.
             if constexpr(get_warp_size() % (M2 * K0) == 0)
             {
                 constexpr index_t M1 = BlockSize / get_warp_size();
-                static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
                 static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
                 constexpr index_t M0 = MPerBlock / (M2 * M1);
+                static_assert(M0 * M1 * M2 == MPerBlock,
+                              "Incorrect M0, M1, M2 configuration! "
+                              "M0, M1, M2 must cover whole MPerBlock!");
+
                 return make_static_tile_distribution(
                     tile_distribution_encoding<sequence<1>,
                                                tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
@@ -245,10 +341,15 @@ struct UniversalGemmPipelineAgBgCrPolicy
                                                sequence<1, 2>,
                                                sequence<0, 1>>{});
             }
+            // Coalesce reading for each wavefront - wavefront raked pattern.
             else
             {
                 constexpr index_t M0 = BlockSize / get_warp_size();
                 constexpr index_t M1 = MPerBlock / (M2 * M0);
+                static_assert(M0 * M1 * M2 == MPerBlock,
+                              "Incorrect M0, M1, M2 configuration! "
+                              "M0, M1, M2 must cover whole MPerBlock!");
+
                 return make_static_tile_distribution(
                     tile_distribution_encoding<sequence<1>,
                                                tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
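To illustrate the workgroup-raked split asserted above, a worked example with illustrative numbers (fp16 A, MPerBlock = 128, KPerBlock = 32, 256 threads, warp size 64; none of these values are mandated by the commit):

```cpp
// Worked example of the RowMajor-A "workgroup raked" branch above.
constexpr int K1 = 8;               // GetVectorSizeA: 16 B / sizeof(fp16)
constexpr int K0 = 32 / K1;         // KPerBlock / K1              = 4
constexpr int M2 = 64 / K0;         // get_warp_size() / K0        = 16
constexpr int M1 = 256 / 64;        // BlockSize / get_warp_size() = 4
constexpr int M0 = 128 / (M2 * M1); // MPerBlock / (M2 * M1)       = 2
static_assert(M0 * M1 * M2 == 128); // the new MPerBlock coverage check
```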
@@ -263,22 +364,20 @@ struct UniversalGemmPipelineAgBgCrPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
     {
-        using BDataType = remove_cvref_t<typename Problem::BDataType>;
-        using BLayout   = remove_cvref_t<typename Problem::BLayout>;
+        using BLayout = remove_cvref_t<typename Problem::BLayout>;

         constexpr index_t BlockSize = Problem::kBlockSize;

         constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

-        if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
+        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
         {
-            constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
+            constexpr index_t N1 = GetVectorSizeB<Problem>();
             constexpr index_t N0 = NPerBlock / N1;
-            constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
-            static_assert(total_pixels % N1 == 0);
-            constexpr index_t K3 = total_pixels / N1;
-            constexpr index_t KPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
+            constexpr index_t elem_per_thr = NPerBlock * KPerBlock / BlockSize;
+            static_assert(elem_per_thr % N1 == 0);
+            constexpr index_t K3 = elem_per_thr / N1;
+            constexpr index_t KPack = GetVectorSizeB<Problem>();
             static_assert(KPack % K3 == 0);
             constexpr index_t K2 = KPack / K3;
             if constexpr(get_warp_size() % (K2 * N0) == 0)
@@ -286,6 +385,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
                 constexpr index_t K1 = get_warp_size() / (K2 * N0);
                 constexpr index_t K0 = BlockSize / get_warp_size();
                 static_assert(KPerBlock == K0 * K1 * K2 * K3);
+
                 return make_static_tile_distribution(
                     tile_distribution_encoding<sequence<1>,
                                                tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
@@ -300,6 +400,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
                 constexpr index_t K2_m = K2 / K1;
                 constexpr index_t K0 = BlockSize / get_warp_size() / K1;
                 static_assert(KPerBlock == K0 * K1 * K2_m * K3);
+
                 return make_static_tile_distribution(
                     tile_distribution_encoding<sequence<1>,
                                                tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
@@ -312,7 +413,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
         else
         {
-            constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType);
+            constexpr index_t K1 = GetVectorSizeB<Problem>();
             constexpr index_t K0 = KPerBlock / K1;
             constexpr index_t N2 = get_warp_size() / K0;
             // coalesce reading for each blocks
@@ -322,6 +423,9 @@ struct UniversalGemmPipelineAgBgCrPolicy
                 static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
                 static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
                 constexpr index_t N0 = NPerBlock / (N2 * N1);
+                static_assert(N0 * N1 * N2 == NPerBlock,
+                              "Incorrect N0, N1, N2 configuration! "
+                              "N0, N1, N2 must cover whole NPerBlock!");

                 return make_static_tile_distribution(
                     tile_distribution_encoding<sequence<1>,
@@ -336,6 +440,9 @@ struct UniversalGemmPipelineAgBgCrPolicy
             {
                 constexpr index_t N0 = BlockSize / get_warp_size();
                 constexpr index_t N1 = NPerBlock / (N2 * N0);
+                static_assert(N0 * N1 * N2 == NPerBlock,
+                              "Incorrect N0, N1, N2 configuration! "
+                              "N0, N1, N2 must cover whole NPerBlock!");

                 return make_static_tile_distribution(
                     tile_distribution_encoding<sequence<1>,
@@ -447,22 +554,21 @@ struct UniversalGemmPipelineAgBgCrPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
     {
-        using AccDataType = float;
         using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
         using WarpTile = typename Problem::BlockGemmShape::WarpTile;
         using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
                                                 typename Problem::BDataType,
-                                                AccDataType,
+                                                typename Problem::CDataType,
                                                 WarpTile::at(I0),
                                                 WarpTile::at(I1),
                                                 WarpTile::at(I2),
-                                                TransposeC>;
+                                                Problem::TransposeC>;
         using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
                                                                       typename Problem::BDataType,
                                                                       typename Problem::CDataType,
                                                                       BlockWarps,
                                                                       WarpGemm>;
-        return BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{};
+        return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
     }
 };
@@ -4,6 +4,7 @@
 #pragma once

 #include "ck_tile/core.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"

 namespace ck_tile {
@@ -19,11 +20,34 @@ struct TileGemmTraits
     static constexpr bool kPadN = kPadN_;
     static constexpr bool kPadK = kPadK_;

+    // TODO this can't be hardcoded here! Should be in policy!
     static constexpr int _VectorSize = 16;

     using ALayout = ALayout_;
     using BLayout = BLayout_;
     using CLayout = CLayout_;
+
+    static constexpr bool TransposeC = false;
+};
+
+template <bool kPadM_,
+          bool kPadN_,
+          bool kPadK_,
+          typename ALayout_,
+          typename BLayout_,
+          typename CLayout_,
+          bool TransposeC_ = false>
+struct TileGemmUniversalTraits
+{
+    static constexpr bool kPadM = kPadM_;
+    static constexpr bool kPadN = kPadN_;
+    static constexpr bool kPadK = kPadK_;
+
+    using ALayout = ALayout_;
+    using BLayout = BLayout_;
+    using CLayout = CLayout_;
+
+    static constexpr bool TransposeC = TransposeC_;
 };

 } // namespace ck_tile
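Putting the pieces of this commit together, a hedged end-to-end sketch (only `MyBlockGemmShape` is a placeholder for a block-shape type defined elsewhere; it is expected to provide kM/kN/kK, NumWarps, BlockWarps and WarpTile):

```cpp
// Sketch only: traits -> problem -> pipeline with the refactored policy.
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor;

using Traits = ck_tile::TileGemmUniversalTraits<false, false, false, // kPadM/N/K
                                                RowMajor,            // ALayout
                                                ColMajor,            // BLayout
                                                RowMajor,            // CLayout
                                                /*TransposeC_=*/false>;

using Problem = ck_tile::UniversalGemmPipelineProblem<ck_tile::half_t, // A
                                                      ck_tile::half_t, // B
                                                      float,           // C accumulator
                                                      MyBlockGemmShape,
                                                      Traits>;

// Default Policy = UniversalGemmPipelineAgBgCrPolicy after this commit.
using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3<Problem>;
```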