Commit f2c1fa7f authored by ThomasNing's avatar ThomasNing
Browse files

Merge branch 'develop' of https://github.com/ROCm/composable_kernel into develop

parents 4658f2f6 0e5e29c4
...@@ -338,7 +338,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -338,7 +338,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{ {
using ALayout = remove_cvref_t<typename Problem::ALayout>; using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>); static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
......
...@@ -36,6 +36,8 @@ struct GemmPipelineProblemBase ...@@ -36,6 +36,8 @@ struct GemmPipelineProblemBase
static constexpr bool kPadN = Traits::kPadN; static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK; static constexpr bool kPadK = Traits::kPadK;
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
static constexpr auto Scheduler = GemmPipelineScheduler::Default; static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr index_t VectorLoadSize = Traits::_VectorSize; static constexpr index_t VectorLoadSize = Traits::_VectorSize;
...@@ -173,6 +175,8 @@ struct UniversalGemmPipelineProblem ...@@ -173,6 +175,8 @@ struct UniversalGemmPipelineProblem
static constexpr bool kPadN = Traits::kPadN; static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK; static constexpr bool kPadK = Traits::kPadK;
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
static constexpr auto Scheduler = Scheduler_; static constexpr auto Scheduler = Scheduler_;
static constexpr auto HasHotLoop = HasHotLoop_; static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_; static constexpr auto TailNum = TailNum_;
......
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
namespace ck_tile { namespace ck_tile {
// UniversalGemm Policy template <typename Derived>
struct UniversalGemmPipelineAgBgCrPolicy struct UniversalGemmBasePolicy
{ {
static constexpr auto I0 = number<0>{}; static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{}; static constexpr auto I1 = number<1>{};
...@@ -113,7 +113,7 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -113,7 +113,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{ {
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>; using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
using WG = typename BlockGemm::WarpGemm; using WG = typename BlockGemm::WarpGemm;
constexpr bool TransposeC = Problem::TransposeC; constexpr bool TransposeC = Problem::TransposeC;
...@@ -166,10 +166,116 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -166,10 +166,116 @@ struct UniversalGemmPipelineAgBgCrPolicy
} }
} }
// Forwards the problem's compile-time TransposeC setting.
// Problem must expose a static constexpr member named TransposeC.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
    constexpr auto transpose_c = Problem::TransposeC;
    return transpose_c;
}
// Builds the 2D static tile distribution for the A operand.
// The order of the two tile extents follows Problem::ALayout:
//   RowMajor  -> MPerBlock x KPerBlock
//   otherwise -> KPerBlock x MPerBlock
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
    using LayoutA = remove_cvref_t<typename Problem::ALayout>;

    constexpr bool a_is_row_major =
        std::is_same_v<LayoutA, ck_tile::tensor_layout::gemm::RowMajor>;

    constexpr index_t block_threads = Problem::kBlockSize;
    constexpr index_t m_extent      = Problem::BlockGemmShape::kM;
    constexpr index_t k_extent      = Problem::BlockGemmShape::kK;
    constexpr index_t vec_width     = GetVectorSizeA<Problem>();

    // Pick the leading/trailing extents once instead of duplicating the pattern per layout.
    constexpr index_t extent0 = a_is_row_major ? m_extent : k_extent;
    constexpr index_t extent1 = a_is_row_major ? k_extent : m_extent;

    using Pattern = TileDistributionEncodingPattern2D<block_threads,
                                                      extent0,
                                                      extent1,
                                                      vec_width,
                                                      ATileAccessPattern>;
    return Pattern::Make2DStaticTileDistribution();
}
// Builds the 2D static tile distribution for the B operand.
// The order of the two tile extents follows Problem::BLayout:
//   RowMajor  -> KPerBlock x NPerBlock
//   otherwise -> NPerBlock x KPerBlock
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{
    using LayoutB = remove_cvref_t<typename Problem::BLayout>;

    constexpr bool b_is_row_major =
        std::is_same_v<LayoutB, ck_tile::tensor_layout::gemm::RowMajor>;

    constexpr index_t block_threads = Problem::kBlockSize;
    constexpr index_t n_extent      = Problem::BlockGemmShape::kN;
    constexpr index_t k_extent      = Problem::BlockGemmShape::kK;
    constexpr index_t vec_width     = GetVectorSizeB<Problem>();

    // Pick the leading/trailing extents once instead of duplicating the pattern per layout.
    constexpr index_t extent0 = b_is_row_major ? k_extent : n_extent;
    constexpr index_t extent1 = b_is_row_major ? n_extent : k_extent;

    using Pattern = TileDistributionEncodingPattern2D<block_threads,
                                                      extent0,
                                                      extent1,
                                                      vec_width,
                                                      BTileAccessPattern>;
    return Pattern::Make2DStaticTileDistribution();
}
// Builds the shuffled 2D static tile distribution for the A operand, encoded as
// KPerBlock x MPerBlock. Only instantiable when Problem::ALayout is ColumnMajor
// (enforced by the static_assert below).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution()
{
    using ALayout = remove_cvref_t<typename Problem::ALayout>;
    static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>,
                  "MakeShuffledARegTileDistribution requires a ColumnMajor A layout");
    constexpr index_t BlockSize = Problem::kBlockSize;
    // The A tile spans the M and K block extents, so MPerBlock must come from kM.
    // Reading kN here only gives the right answer when the block tile has kM == kN.
    constexpr index_t MPerBlock   = Problem::BlockGemmShape::kM;
    constexpr index_t KPerBlock   = Problem::BlockGemmShape::kK;
    constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
    using TileEncodingPattern     = TileDistributionEncodingPattern2D<BlockSize,
                                                                  KPerBlock,
                                                                  MPerBlock,
                                                                  VecLoadSize,
                                                                  ATileAccessPattern>;
    return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
}
// Builds the shuffled 2D static tile distribution for the B operand, encoded as
// KPerBlock x NPerBlock. Only instantiable when Problem::BLayout is RowMajor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution()
{
    using LayoutB = remove_cvref_t<typename Problem::BLayout>;
    static_assert(std::is_same_v<LayoutB, ck_tile::tensor_layout::gemm::RowMajor>);

    constexpr index_t block_threads = Problem::kBlockSize;
    constexpr index_t k_extent      = Problem::BlockGemmShape::kK;
    constexpr index_t n_extent      = Problem::BlockGemmShape::kN;
    constexpr index_t vec_width     = GetVectorSizeB<Problem>();

    using Pattern = TileDistributionEncodingPattern2D<block_threads,
                                                      k_extent,
                                                      n_extent,
                                                      vec_width,
                                                      BTileAccessPattern>;
    return Pattern::MakeShuffled2DStaticTileDistribution();
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{ {
using BlockGemm = decltype(GetBlockGemm<Problem>()); using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
constexpr index_t KPack = BlockGemm::Traits::KPack; constexpr index_t KPack = BlockGemm::Traits::KPack;
return KPack; return KPack;
} }
...@@ -177,11 +283,43 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -177,11 +283,43 @@ struct UniversalGemmPipelineAgBgCrPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
{ {
using BlockGemm = decltype(GetBlockGemm<Problem>()); using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
constexpr index_t KPack = BlockGemm::Traits::KPack; constexpr index_t KPack = BlockGemm::Traits::KPack;
return KPack; return KPack;
} }
// Byte size of the A block descriptor supplied by Derived, rounded up to a multiple of 16.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
    constexpr auto lds_block_desc = Derived::template MakeALdsBlockDescriptor<Problem>();
    constexpr auto element_count  = lds_block_desc.get_element_space_size();

    // element size in bytes * element count, then padded to the next multiple of 16
    constexpr index_t padded_bytes =
        integer_least_multiple(sizeof(typename Problem::ADataType) * element_count, 16);
    return padded_bytes;
}
// Byte size of the B block descriptor supplied by Derived, rounded up to a multiple of 16.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
    constexpr auto lds_block_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
    constexpr auto element_count  = lds_block_desc.get_element_space_size();

    // element size in bytes * element count, then padded to the next multiple of 16
    constexpr index_t padded_bytes =
        integer_least_multiple(sizeof(typename Problem::BDataType) * element_count, 16);
    return padded_bytes;
}
// Combined size: the sum of GetSmemSizeA and GetSmemSizeB for the given Problem.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    constexpr index_t bytes_for_a = GetSmemSizeA<Problem>();
    constexpr index_t bytes_for_b = GetSmemSizeB<Problem>();

    constexpr index_t total_bytes = bytes_for_a + bytes_for_b;
    return total_bytes;
}
};
// UniversalGemm Policy
struct UniversalGemmPipelineAgBgCrPolicy
: public UniversalGemmBasePolicy<UniversalGemmPipelineAgBgCrPolicy>
{
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{ {
...@@ -421,133 +559,6 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -421,133 +559,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
#endif #endif
} }
// Byte size of the A block descriptor: element size times element-space size.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
    constexpr index_t total_bytes =
        sizeof(typename Problem::ADataType) *
        MakeALdsBlockDescriptor<Problem>().get_element_space_size();

    return total_bytes;
}
// Byte size of the B block descriptor: element size times element-space size.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
    constexpr index_t total_bytes =
        sizeof(typename Problem::BDataType) *
        MakeBLdsBlockDescriptor<Problem>().get_element_space_size();

    return total_bytes;
}
// Combined size: the sum of GetSmemSizeA and GetSmemSizeB for the given Problem.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    constexpr index_t bytes_for_a = GetSmemSizeA<Problem>();
    constexpr index_t bytes_for_b = GetSmemSizeB<Problem>();

    return bytes_for_a + bytes_for_b;
}
// Builds the 2D static tile distribution for the A operand.
// The order of the two tile extents follows Problem::ALayout:
//   RowMajor  -> MPerBlock x KPerBlock
//   otherwise -> KPerBlock x MPerBlock
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
    using LayoutA = remove_cvref_t<typename Problem::ALayout>;

    constexpr bool a_is_row_major =
        std::is_same_v<LayoutA, ck_tile::tensor_layout::gemm::RowMajor>;

    constexpr index_t block_threads = Problem::kBlockSize;
    constexpr index_t m_extent      = Problem::BlockGemmShape::kM;
    constexpr index_t k_extent      = Problem::BlockGemmShape::kK;
    constexpr index_t vec_width     = GetVectorSizeA<Problem>();

    // Select the extent order from the layout rather than repeating the pattern alias.
    constexpr index_t extent0 = a_is_row_major ? m_extent : k_extent;
    constexpr index_t extent1 = a_is_row_major ? k_extent : m_extent;

    using Pattern = TileDistributionEncodingPattern2D<block_threads,
                                                      extent0,
                                                      extent1,
                                                      vec_width,
                                                      ATileAccessPattern>;
    return Pattern::Make2DStaticTileDistribution();
}
// Builds the 2D static tile distribution for the B operand.
// The order of the two tile extents follows Problem::BLayout:
//   RowMajor  -> KPerBlock x NPerBlock
//   otherwise -> NPerBlock x KPerBlock
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{
    using LayoutB = remove_cvref_t<typename Problem::BLayout>;

    constexpr bool b_is_row_major =
        std::is_same_v<LayoutB, ck_tile::tensor_layout::gemm::RowMajor>;

    constexpr index_t block_threads = Problem::kBlockSize;
    constexpr index_t n_extent      = Problem::BlockGemmShape::kN;
    constexpr index_t k_extent      = Problem::BlockGemmShape::kK;
    constexpr index_t vec_width     = GetVectorSizeB<Problem>();

    // Select the extent order from the layout rather than repeating the pattern alias.
    constexpr index_t extent0 = b_is_row_major ? k_extent : n_extent;
    constexpr index_t extent1 = b_is_row_major ? n_extent : k_extent;

    using Pattern = TileDistributionEncodingPattern2D<block_threads,
                                                      extent0,
                                                      extent1,
                                                      vec_width,
                                                      BTileAccessPattern>;
    return Pattern::Make2DStaticTileDistribution();
}
// Builds the shuffled 2D static tile distribution for the A operand, encoded as
// KPerBlock x MPerBlock. Only instantiable when Problem::ALayout is ColumnMajor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution()
{
    using LayoutA = remove_cvref_t<typename Problem::ALayout>;
    static_assert(std::is_same_v<LayoutA, ck_tile::tensor_layout::gemm::ColumnMajor>);

    constexpr index_t block_threads = Problem::kBlockSize;
    constexpr index_t k_extent      = Problem::BlockGemmShape::kK;
    constexpr index_t m_extent      = Problem::BlockGemmShape::kM;
    constexpr index_t vec_width     = GetVectorSizeA<Problem>();

    using Pattern = TileDistributionEncodingPattern2D<block_threads,
                                                      k_extent,
                                                      m_extent,
                                                      vec_width,
                                                      ATileAccessPattern>;
    return Pattern::MakeShuffled2DStaticTileDistribution();
}
// Builds the shuffled 2D static tile distribution for the B operand, encoded as
// KPerBlock x NPerBlock. Only instantiable when Problem::BLayout is RowMajor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution()
{
    using LayoutB = remove_cvref_t<typename Problem::BLayout>;
    static_assert(std::is_same_v<LayoutB, ck_tile::tensor_layout::gemm::RowMajor>);

    constexpr index_t block_threads = Problem::kBlockSize;
    constexpr index_t k_extent      = Problem::BlockGemmShape::kK;
    constexpr index_t n_extent      = Problem::BlockGemmShape::kN;
    constexpr index_t vec_width     = GetVectorSizeB<Problem>();

    using Pattern = TileDistributionEncodingPattern2D<block_threads,
                                                      k_extent,
                                                      n_extent,
                                                      vec_width,
                                                      BTileAccessPattern>;
    return Pattern::MakeShuffled2DStaticTileDistribution();
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{ {
......
...@@ -32,6 +32,7 @@ struct TileGemmTraits ...@@ -32,6 +32,7 @@ struct TileGemmTraits
template <bool kPadM_, template <bool kPadM_,
bool kPadN_, bool kPadN_,
bool kPadK_, bool kPadK_,
bool DoubleSmemBuffer_,
typename ALayout_, typename ALayout_,
typename BLayout_, typename BLayout_,
typename CLayout_, typename CLayout_,
...@@ -42,6 +43,8 @@ struct TileGemmUniversalTraits ...@@ -42,6 +43,8 @@ struct TileGemmUniversalTraits
static constexpr bool kPadN = kPadN_; static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_; static constexpr bool kPadK = kPadK_;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
using ALayout = ALayout_; using ALayout = ALayout_;
using BLayout = BLayout_; using BLayout = BLayout_;
using CLayout = CLayout_; using CLayout = CLayout_;
......
...@@ -17,22 +17,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, ...@@ -17,22 +17,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Interwave>; ck_tile::GemmPipelineScheduler::Interwave>;
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>; using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>; using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem> std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
>; >;
// clang-format on // clang-format on
......
...@@ -14,7 +14,32 @@ ...@@ -14,7 +14,32 @@
enum struct GemmPipelineType enum struct GemmPipelineType
{ {
Mem, Mem,
Comp CompV3,
CompV4
};
// Compile-time map from a GemmPipelineType enumerator to the ck_tile pipeline
// class templates instantiated for Problem. Each specialization exposes:
//   base_pipeline - the Base* pipeline type
//   pipeline      - the full pipeline type
// The primary template is only declared, so an enumerator without a
// specialization below cannot be used.
template <GemmPipelineType PT, typename Problem>
struct GemmPipelineTypeSelector;
// GemmPipelineType::Mem -> *AgBgCrMem pipelines.
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::Mem, Problem>
{
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrMem<Problem>;
using pipeline = ck_tile::GemmPipelineAgBgCrMem<Problem>;
};
// GemmPipelineType::CompV3 -> *AgBgCrCompV3 pipelines.
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::CompV3, Problem>
{
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<Problem>;
using pipeline = ck_tile::GemmPipelineAgBgCrCompV3<Problem>;
};
// GemmPipelineType::CompV4 -> *AgBgCrCompV4 pipelines.
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, Problem>
{
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<Problem>;
using pipeline = ck_tile::GemmPipelineAgBgCrCompV4<Problem>;
}; };
template <typename Tuple> template <typename Tuple>
...@@ -36,8 +61,8 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -36,8 +61,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{ {
// TODO: This should be parameterized in tests // TODO: This should be parameterized in tests
constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 128; constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32; constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t M_Warp = 2;
...@@ -52,6 +77,8 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -52,6 +77,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr bool kPadN = PadN; constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK; constexpr bool kPadK = PadK;
constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4) ? true : false;
// TODO: For now - but this should also be a test parameter // TODO: For now - but this should also be a test parameter
constexpr bool TransposeC = false; constexpr bool TransposeC = false;
...@@ -69,16 +96,20 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -69,16 +96,20 @@ class TestCkTileGemmPipeline : public ::testing::Test
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>; GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile:: using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
TileGemmUniversalTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, TransposeC>; kPadN,
kPadK,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC>;
using GemmPipelineProblem = using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>; ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = using BaseGemmPipeline =
std::conditional_t<PipelineType == GemmPipelineType::Mem, typename GemmPipelineTypeSelector<PipelineType, GemmPipelineProblem>::base_pipeline;
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
const ck_tile::index_t k_grain = args.k_batch * K_Tile; const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
...@@ -99,12 +130,8 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -99,12 +130,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
has_hot_loop_v, has_hot_loop_v,
tail_number_v>; tail_number_v>;
using GemmPipeline = std::conditional_t< using GemmPipeline =
PipelineType == GemmPipelineType::Mem, typename GemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>,
ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue< using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType, ck_tile::CShuffleEpilogueProblem<AccDataType,
...@@ -145,7 +172,7 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -145,7 +172,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
if(has_hot_loop) if(has_hot_loop)
{ {
if constexpr(PipelineType == GemmPipelineType::Comp) if constexpr(PipelineType == GemmPipelineType::CompV3)
{ {
if(tail_num == ck_tile::TailNumber::Full) if(tail_num == ck_tile::TailNumber::Full)
{ {
...@@ -235,6 +262,22 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -235,6 +262,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
} }
} }
} }
if constexpr(PipelineType == GemmPipelineType::CompV4)
{
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Three>{});
}
else
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Two>{});
}
}
} }
else else
{ {
...@@ -258,7 +301,19 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -258,7 +301,19 @@ class TestCkTileGemmPipeline : public ::testing::Test
public: public:
std::vector<int> k_batches_; std::vector<int> k_batches_;
void SetUp() override { k_batches_ = {1, 2}; } void SetUp() override
{
if constexpr(PipelineType == GemmPipelineType::CompV4)
{
// Only do k_batch = 1 when pipeline is CompV4
k_batches_ = {1};
}
else
{
// Otherwise, use k_batch = 1 and 2
k_batches_ = {1, 2};
}
}
template <bool PadM = true, bool PadN = true, bool PadK = true> template <bool PadM = true, bool PadN = true, bool PadK = true>
void Run(const int M, void Run(const int M,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment