Commit d1715c0f authored by ThomasNing's avatar ThomasNing
Browse files

Fix the gtest compilation error

parent 987cc54d
...@@ -42,7 +42,9 @@ struct BaseGemmPipelineAgBgCrCompV4 ...@@ -42,7 +42,9 @@ struct BaseGemmPipelineAgBgCrCompV4
// the ping-pong buffer to grab memory from the global memory. While one LDS is grabbing the data // the ping-pong buffer to grab memory from the global memory. While one LDS is grabbing the data
// from global memory, the other will call the warps on running the MFMA matrix multiplication. When // from global memory, the other will call the warps on running the MFMA matrix multiplication. When
// the matrix is in bigger shape, it will keep the Warp always busy and cover the memory loading // the matrix is in bigger shape, it will keep the Warp always busy and cover the memory loading
// time. // time. It will have better performance comparing to the Compute Version 3 when they have the same
// block tile and better performance when you have M, N, K all > 8K even when the compute V3 block
// size is 2 times of the compute V4.
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy> template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy>
struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem> struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
{ {
......
...@@ -17,7 +17,8 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, ...@@ -17,7 +17,8 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
// using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, // using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
// ck_tile::GemmPipelineScheduler::Interwave>; // ck_tile::GemmPipelineScheduler::Interwave>;
// using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>; // using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>; // using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
// TODO: Enable Memory pipeline, when it would be updated for vector loads on non-K major tensors. // TODO: Enable Memory pipeline, when it would be updated for vector loads on non-K major tensors.
...@@ -25,16 +26,16 @@ using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp ...@@ -25,16 +26,16 @@ using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
// std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, // std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
// std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>, // std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
// std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, // std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
// std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>, // std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
// std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, // std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
// std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>, // std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
// std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, // std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp> std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>
// std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem> // std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
>; >;
// clang-format on // clang-format on
......
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
enum struct GemmPipelineType enum struct GemmPipelineType
{ {
Mem, Mem,
Comp CompV3,
CompV4
}; };
template <typename Tuple> template <typename Tuple>
...@@ -52,6 +53,8 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -52,6 +53,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr bool kPadN = PadN; constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK; constexpr bool kPadK = PadK;
constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4) ? true : false;
// TODO: For now - but this should also be a test parameter // TODO: For now - but this should also be a test parameter
constexpr bool TransposeC = false; constexpr bool TransposeC = false;
...@@ -69,16 +72,24 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -69,16 +72,24 @@ class TestCkTileGemmPipeline : public ::testing::Test
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>; GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile:: using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
TileGemmUniversalTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, TransposeC>; kPadN,
kPadK,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC>;
using GemmPipelineProblem = using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>; ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = using BaseGemmPipeline = std::conditional_t<
std::conditional_t<PipelineType == GemmPipelineType::Mem, PipelineType == GemmPipelineType::Mem,
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>, ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>; std::conditional_t<PipelineType == GemmPipelineType::CompV3,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV4<GemmPipelineProblem>>>;
const ck_tile::index_t k_grain = args.k_batch * K_Tile; const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
...@@ -103,8 +114,11 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -103,8 +114,11 @@ class TestCkTileGemmPipeline : public ::testing::Test
PipelineType == GemmPipelineType::Mem, PipelineType == GemmPipelineType::Mem,
ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem, ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>, ck_tile::UniversalGemmPipelineAgBgCrPolicy>,
std::conditional_t<
PipelineType == GemmPipelineType::CompV3,
ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem, ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>>; ck_tile::UniversalGemmPipelineAgBgCrPolicy>,
ck_tile::GemmPipelineAgBgCrCompV4<UniversalGemmProblem>>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue< using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType, ck_tile::CShuffleEpilogueProblem<AccDataType,
...@@ -145,7 +159,7 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -145,7 +159,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
if(has_hot_loop) if(has_hot_loop)
{ {
if constexpr(PipelineType == GemmPipelineType::Comp) if constexpr(PipelineType == GemmPipelineType::CompV3)
{ {
if(tail_num == ck_tile::TailNumber::Full) if(tail_num == ck_tile::TailNumber::Full)
{ {
...@@ -235,6 +249,22 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -235,6 +249,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
} }
} }
} }
if constexpr(PipelineType == GemmPipelineType::CompV4)
{
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Three>{});
}
else
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Two>{});
}
}
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment