Unverified Commit 3e178e7c authored by Muhammed  Emin Ozturk's avatar Muhammed Emin Ozturk Committed by GitHub
Browse files

Merge branch 'develop' into muozturk_bf16fp8_streamk

parents 27fb084f 0e5e29c4
...@@ -17,22 +17,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, ...@@ -17,22 +17,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Interwave>; ck_tile::GemmPipelineScheduler::Interwave>;
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>; using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>; using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem> std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
>; >;
// clang-format on // clang-format on
......
...@@ -14,7 +14,32 @@ ...@@ -14,7 +14,32 @@
enum struct GemmPipelineType enum struct GemmPipelineType
{ {
Mem, Mem,
Comp CompV3,
CompV4
};
template <GemmPipelineType PT, typename Problem>
struct GemmPipelineTypeSelector;
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::Mem, Problem>
{
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrMem<Problem>;
using pipeline = ck_tile::GemmPipelineAgBgCrMem<Problem>;
};
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::CompV3, Problem>
{
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<Problem>;
using pipeline = ck_tile::GemmPipelineAgBgCrCompV3<Problem>;
};
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, Problem>
{
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<Problem>;
using pipeline = ck_tile::GemmPipelineAgBgCrCompV4<Problem>;
}; };
template <typename Tuple> template <typename Tuple>
...@@ -36,8 +61,8 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -36,8 +61,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{ {
// TODO: This should be parameterized in tests // TODO: This should be parameterized in tests
constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 128; constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32; constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t M_Warp = 2;
...@@ -52,6 +77,8 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -52,6 +77,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr bool kPadN = PadN; constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK; constexpr bool kPadK = PadK;
constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4) ? true : false;
// TODO: For now - but this should also be a test parameter // TODO: For now - but this should also be a test parameter
constexpr bool TransposeC = false; constexpr bool TransposeC = false;
...@@ -69,16 +96,20 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -69,16 +96,20 @@ class TestCkTileGemmPipeline : public ::testing::Test
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>; GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile:: using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
TileGemmUniversalTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, TransposeC>; kPadN,
kPadK,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC>;
using GemmPipelineProblem = using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>; ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = using BaseGemmPipeline =
std::conditional_t<PipelineType == GemmPipelineType::Mem, typename GemmPipelineTypeSelector<PipelineType, GemmPipelineProblem>::base_pipeline;
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
const ck_tile::index_t k_grain = args.k_batch * K_Tile; const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
...@@ -99,12 +130,8 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -99,12 +130,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
has_hot_loop_v, has_hot_loop_v,
tail_number_v>; tail_number_v>;
using GemmPipeline = std::conditional_t< using GemmPipeline =
PipelineType == GemmPipelineType::Mem, typename GemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>,
ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue< using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType, ck_tile::CShuffleEpilogueProblem<AccDataType,
...@@ -145,7 +172,7 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -145,7 +172,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
if(has_hot_loop) if(has_hot_loop)
{ {
if constexpr(PipelineType == GemmPipelineType::Comp) if constexpr(PipelineType == GemmPipelineType::CompV3)
{ {
if(tail_num == ck_tile::TailNumber::Full) if(tail_num == ck_tile::TailNumber::Full)
{ {
...@@ -235,6 +262,22 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -235,6 +262,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
} }
} }
} }
if constexpr(PipelineType == GemmPipelineType::CompV4)
{
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Three>{});
}
else
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Two>{});
}
}
} }
else else
{ {
...@@ -258,7 +301,19 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -258,7 +301,19 @@ class TestCkTileGemmPipeline : public ::testing::Test
public: public:
std::vector<int> k_batches_; std::vector<int> k_batches_;
void SetUp() override { k_batches_ = {1, 2}; } void SetUp() override
{
if constexpr(PipelineType == GemmPipelineType::CompV4)
{
// Only do k_batch = 1 when pipeline is CompV4
k_batches_ = {1};
}
else
{
// Otherwise, use k_batch = 1 and 2
k_batches_ = {1, 2};
}
}
template <bool PadM = true, bool PadN = true, bool PadK = true> template <bool PadM = true, bool PadN = true, bool PadK = true>
void Run(const int M, void Run(const int M,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment