"docs/vscode:/vscode.git/clone" did not exist on "6133d98ff70eafad7b9f65da50a450a965d1957f"
Commit 7cbc1492 authored by Adam Osewski's avatar Adam Osewski
Browse files

Update unit-tests & disable mem pipeline.

parent e62670af
...@@ -14,26 +14,28 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; ...@@ -14,26 +14,28 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Intrawave>; ck_tile::GemmPipelineScheduler::Intrawave>;
using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, // using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Interwave>; // ck_tile::GemmPipelineScheduler::Interwave>;
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>; // using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>; using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>;
// TODO: Enable Memory pipeline, when it would be updated for vector loads on non-K major tensors.
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, // std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>, // std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, // std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>, // std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, // std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>, // std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, // std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem> // std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
>; >;
// clang-format on // clang-format on
......
...@@ -10,22 +10,43 @@ TYPED_TEST(TestCkTileGemmPipeline, SmallM) ...@@ -10,22 +10,43 @@ TYPED_TEST(TestCkTileGemmPipeline, SmallM)
constexpr int K = 320; constexpr int K = 320;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K); {
if constexpr(std::is_same_v<typename TestFixture::ALayout,
ck_tile::tensor_layout::gemm::ColumnMajor>)
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
else
this->Run(M, N, K);
}
} }
TYPED_TEST(TestCkTileGemmPipeline, MidLargeM) TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
{ {
std::vector<int> Ms{127, 255, 312, 799, 1573}; std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 1024; constexpr int N = 1024;
constexpr int K = 320; constexpr int K = 320;
constexpr int VecLoadSize = 8;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K); {
if constexpr(std::is_same_v<typename TestFixture::ALayout,
ck_tile::tensor_layout::gemm::ColumnMajor>)
{
// TODO: Can we anyhow deduce used vector load size?
if(M % VecLoadSize == 0)
this->Run(M, N, K);
else
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
}
else
{
this->Run(M, N, K);
}
}
} }
TYPED_TEST(TestCkTileGemmPipeline, PaddK) TYPED_TEST(TestCkTileGemmPipeline, PaddK)
{ {
std::vector<int> Ms{127}; std::vector<int> Ms{128};
constexpr int N = 1024; constexpr int N = 1024;
constexpr int K = 432; constexpr int K = 432;
......
...@@ -51,6 +51,9 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -51,6 +51,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr bool kPadN = PadN; constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK; constexpr bool kPadK = PadK;
// TODO: For now - but this should also be a test parameter
constexpr bool TransposeC = false;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
// =============================================== // ===============================================
...@@ -65,14 +68,16 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -65,14 +68,16 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile::
TileGemmUniversalTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = std::conditional_t< using BaseGemmPipeline =
PipelineType == GemmPipelineType::Mem, std::conditional_t<PipelineType == GemmPipelineType::Mem,
ck_tile::BaseGemmPipelineAgBgCrMem< ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>, ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
ck_tile::BaseGemmPipelineAgBgCrCompV3<
ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>>;
const ck_tile::index_t k_grain = args.k_batch * K_Tile; const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
...@@ -84,26 +89,22 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -84,26 +89,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value; constexpr auto tail_number_v = tail_number_.value;
using GemmPipeline = using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
std::conditional_t<PipelineType == GemmPipelineType::Mem, BDataType,
ck_tile::GemmPipelineAgBgCrMem< AccDataType,
ck_tile::UniversalGemmPipelineProblem<ADataType, GemmShape,
BDataType, GemmUniversalTraits,
AccDataType, Scheduler,
GemmShape, has_hot_loop_v,
Traits, tail_number_v>;
Scheduler,
has_hot_loop_v, using GemmPipeline = std::conditional_t<
tail_number_v>>, PipelineType == GemmPipelineType::Mem,
ck_tile::GemmPipelineAgBgCrCompV3< ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem,
ck_tile::UniversalGemmPipelineProblem<ADataType, ck_tile::UniversalGemmPipelineAgBgCrPolicy>,
BDataType, ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem,
AccDataType, ck_tile::UniversalGemmPipelineAgBgCrPolicy>>;
GemmShape,
Traits,
Scheduler,
has_hot_loop_v,
tail_number_v>>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args); auto kargs = Kernel::MakeKernelArgs(args);
...@@ -129,70 +130,94 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -129,70 +130,94 @@ class TestCkTileGemmPipeline : public ::testing::Test
if(has_hot_loop) if(has_hot_loop)
{ {
// Tail pipeline One to Seven if constexpr(PipelineType == GemmPipelineType::Comp)
if(tail_num == ck_tile::TailNumber::One)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{ {
if(tail_num == ck_tile::TailNumber::Two) if(tail_num == ck_tile::TailNumber::Full)
{ {
Run(ck_tile::bool_constant<true>{}, Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Two>{}); ck_tile::TailNumber::Full>{});
} }
} else
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
{
if(tail_num == ck_tile::TailNumber::Three)
{ {
Run(ck_tile::bool_constant<true>{}, std::ostringstream err;
ck_tile::integral_constant<ck_tile::TailNumber, err << "For compute pipeline tail number should always be Full, but have \""
ck_tile::TailNumber::Three>{}); << tail_num << "\" which is not supported! PrefetchStages: "
<< BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
} }
} }
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
if constexpr(PipelineType == GemmPipelineType::Mem)
{ {
if(tail_num == ck_tile::TailNumber::Four) // Tail pipeline One to Seven
if(tail_num == ck_tile::TailNumber::One)
{ {
Run(ck_tile::bool_constant<true>{}, Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Four>{}); ck_tile::TailNumber::One>{});
} }
} else if(tail_num == ck_tile::TailNumber::Full)
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
{
if(tail_num == ck_tile::TailNumber::Five)
{ {
Run(ck_tile::bool_constant<true>{}, Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Five>{}); ck_tile::TailNumber::Full>{});
} }
}
if constexpr(BaseGemmPipeline::PrefetchStages > 6) if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{
if(tail_num == ck_tile::TailNumber::Six)
{ {
Run(ck_tile::bool_constant<true>{}, if(tail_num == ck_tile::TailNumber::Two)
ck_tile::integral_constant<ck_tile::TailNumber, {
ck_tile::TailNumber::Six>{}); Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Two>{});
}
} }
} if constexpr(BaseGemmPipeline::PrefetchStages > 3)
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
{
if(tail_num == ck_tile::TailNumber::Seven)
{ {
Run(ck_tile::bool_constant<true>{}, if(tail_num == ck_tile::TailNumber::Three)
ck_tile::integral_constant<ck_tile::TailNumber, {
ck_tile::TailNumber::Seven>{}); Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Three>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
{
if(tail_num == ck_tile::TailNumber::Four)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Four>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
{
if(tail_num == ck_tile::TailNumber::Five)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Five>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
{
if(tail_num == ck_tile::TailNumber::Six)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Six>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
{
if(tail_num == ck_tile::TailNumber::Seven)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Seven>{});
}
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment