Unverified Commit 9533a172 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge branch 'develop' into codegen-enable-hiprtc

parents c2cf0733 50ee4267
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple> #include <tuple>
#include <vector> #include <vector>
...@@ -10,25 +10,35 @@ ...@@ -10,25 +10,35 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "test_grouped_gemm_util.hpp" #include "test_grouped_gemm_util.hpp"
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using I8 = int8_t;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>; template <typename Tuple>
using RCR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>; class TestGroupedGemm : public ck::test::TestGroupedGemm<Tuple>
{
using RRR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>; };
using RCR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
// clang-format off
const std::vector<int> KBATCH{1, 2, 3, 5, 8}; using KernelTypes = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F16>,
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH)); std::tuple< Row, Col, Row, F16, F16, F16>,
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_NK, RCR_F16_F16_F16, testing::ValuesIn(KBATCH)); std::tuple< Col, Row, Row, F16, F16, F16>,
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_KN, std::tuple< Col, Col, Row, F16, F16, F16>,
RRR_F16_F16_F16_LargeK, std::tuple< Row, Row, Row, BF16, BF16, BF16>,
testing::Values(32, 64)); std::tuple< Row, Col, Row, BF16, BF16, BF16>,
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_NK, std::tuple< Col, Row, Row, BF16, BF16, BF16>,
RCR_F16_F16_F16_LargeK, std::tuple< Row, Row, Row, BF16, I8, BF16>,
testing::Values(32, 64)); std::tuple< Row, Col, Row, BF16, I8, BF16>,
std::tuple< Row, Row, Row, F16, F8, F16>,
std::tuple< Row, Row, Row, F8, F16, F16>
>;
// clang-format on
TYPED_TEST_SUITE(TestGroupedGemm, KernelTypes);
#include "test_grouped_gemm_ut_cases.inc" #include "test_grouped_gemm_ut_cases.inc"
#pragma once #pragma once
TEST_P(RRR_F16_F16_F16, TinyCases) TYPED_TEST(TestGroupedGemm, TinyCases)
{ {
const std::vector<int> Ms{0, 1}; const std::vector<int> Ms{0, 1};
constexpr int N = 768; constexpr int N = 768;
...@@ -8,14 +8,11 @@ TEST_P(RRR_F16_F16_F16, TinyCases) ...@@ -8,14 +8,11 @@ TEST_P(RRR_F16_F16_F16, TinyCases)
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks);
} }
TEST_P(RRR_F16_F16_F16, SmallCases) TYPED_TEST(TestGroupedGemm, SmallCases)
{ {
const std::vector<int> Ms{2, 1, 3, 4, 5, 0}; const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
constexpr int N = 768; constexpr int N = 768;
...@@ -23,14 +20,11 @@ TEST_P(RRR_F16_F16_F16, SmallCases) ...@@ -23,14 +20,11 @@ TEST_P(RRR_F16_F16_F16, SmallCases)
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks);
} }
TEST_P(RRR_F16_F16_F16, MidCases) TYPED_TEST(TestGroupedGemm, MidCases)
{ {
const std::vector<int> Ms{167, 183, 177, 153, 139, 204}; const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
constexpr int N = 768; constexpr int N = 768;
...@@ -38,14 +32,11 @@ TEST_P(RRR_F16_F16_F16, MidCases) ...@@ -38,14 +32,11 @@ TEST_P(RRR_F16_F16_F16, MidCases)
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks);
} }
TEST_P(RRR_F16_F16_F16, Regular) TYPED_TEST(TestGroupedGemm, Regular)
{ {
const std::vector<int> Ms{64, 128, 256}; const std::vector<int> Ms{64, 128, 256};
constexpr int N = 768; constexpr int N = 768;
...@@ -53,14 +44,11 @@ TEST_P(RRR_F16_F16_F16, Regular) ...@@ -53,14 +44,11 @@ TEST_P(RRR_F16_F16_F16, Regular)
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks);
} }
TEST_P(RRR_F16_F16_F16, MNKPadded) TYPED_TEST(TestGroupedGemm, MNKPadded)
{ {
const std::vector<int> Ms{127, 150, 188, 210}; const std::vector<int> Ms{127, 150, 188, 210};
constexpr int N = 136; constexpr int N = 136;
...@@ -68,88 +56,11 @@ TEST_P(RRR_F16_F16_F16, MNKPadded) ...@@ -68,88 +56,11 @@ TEST_P(RRR_F16_F16_F16, MNKPadded)
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks);
} }
TEST_P(RCR_F16_F16_F16, TinyCases) TYPED_TEST(TestGroupedGemm, TestLargeKBatch)
{
const std::vector<int> Ms{0, 1};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, SmallCases)
{
const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, MidCases)
{
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, Regular)
{
const std::vector<int> Ms{32, 64, 128, 256};
constexpr int N = 768;
constexpr int K = 320;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, MNKPadded)
{
const std::vector<int> Ms{127, 150, 188, 210};
constexpr int N = 136;
constexpr int K = 280;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
{ {
const std::vector<int> Ms{188, 210}; const std::vector<int> Ms{188, 210};
constexpr int N = 768; constexpr int N = 768;
...@@ -157,24 +68,8 @@ TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch) ...@@ -157,24 +68,8 @@ TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16_LargeK, TestLargeKBatch) this->k_batches_ = {32, 64};
{
const std::vector<int> Ms{188, 210};
constexpr int N = 768;
constexpr int K = 4096;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks);
} }
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include "ck/utility/tuple.hpp" #include "ck/utility/tuple.hpp"
#include "ck/utility/number.hpp" #include "ck/utility/number.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp" #include "profiler/profile_grouped_gemm_impl.hpp"
#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
namespace ck { namespace ck {
namespace test { namespace test {
...@@ -40,7 +39,7 @@ std::string serialize_range(const Range& range) ...@@ -40,7 +39,7 @@ std::string serialize_range(const Range& range)
} }
template <typename Tuple> template <typename Tuple>
class TestGroupedGemm : public testing::TestWithParam<int> class TestGroupedGemm : public testing::Test
{ {
protected: protected:
using ALayout = std::tuple_element_t<0, Tuple>; using ALayout = std::tuple_element_t<0, Tuple>;
...@@ -50,23 +49,77 @@ class TestGroupedGemm : public testing::TestWithParam<int> ...@@ -50,23 +49,77 @@ class TestGroupedGemm : public testing::TestWithParam<int>
using BDataType = std::tuple_element_t<4, Tuple>; using BDataType = std::tuple_element_t<4, Tuple>;
using EDataType = std::tuple_element_t<5, Tuple>; using EDataType = std::tuple_element_t<5, Tuple>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
public: public:
static constexpr bool verify_ = true; static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // decimal value initialization static constexpr int init_method_ = 1; // integer value initialization
static constexpr bool log_ = false; static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance static constexpr bool bench_ = false; // measure kernel performance
static constexpr int n_warmup_ = 0;
static constexpr int n_iter_ = 1;
std::vector<int> k_batches_;
void SetUp() override {} void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; }
private:
template <typename Layout>
void SetStrides(std::vector<int>& strides,
const std::vector<int>& rows,
const std::vector<int>& cols) const
{
if(std::is_same_v<Layout, Row>)
{
for(const auto c : cols)
{
strides.emplace_back(c);
}
}
else if(std::is_same_v<Layout, Col>)
{
for(const auto r : rows)
{
strides.emplace_back(r);
}
}
}
public:
void Run(const std::vector<int>& Ms, void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns, const std::vector<int>& Ns,
const std::vector<int>& Ks, const std::vector<int>& Ks,
const std::vector<int>& StrideAs, const std::vector<int>& StrideAs = {},
const std::vector<int>& StrideBs, const std::vector<int>& StrideBs = {},
const std::vector<int>& StrideCs, const std::vector<int>& StrideCs = {})
int kbatch = 1, {
int n_warmup = 1, std::vector<int> stride_as = StrideAs;
int n_iter = 10) std::vector<int> stride_bs = StrideBs;
std::vector<int> stride_cs = StrideCs;
if(stride_as.empty())
{
SetStrides<ALayout>(stride_as, Ms, Ks);
}
if(stride_bs.empty())
{
SetStrides<BLayout>(stride_bs, Ks, Ns);
}
if(stride_cs.empty())
{
SetStrides<ELayout>(stride_cs, Ms, Ns);
}
RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_cs, k_batches_);
}
void RunSingle(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
const std::vector<int>& kbatches)
{ {
bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType, bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType, BDataType,
...@@ -84,61 +137,9 @@ class TestGroupedGemm : public testing::TestWithParam<int> ...@@ -84,61 +137,9 @@ class TestGroupedGemm : public testing::TestWithParam<int>
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs, StrideCs,
kbatch, kbatches,
n_warmup, n_warmup_,
n_iter); n_iter_);
EXPECT_TRUE(pass);
}
};
template <typename Tuple>
class TestGroupedGemmTwoStage : public testing::TestWithParam<int>
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using ELayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using EDataType = std::tuple_element_t<5, Tuple>;
public:
static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // decimal value initialization
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
void SetUp() override {}
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1,
int n_warmup = 1,
int n_iter = 10)
{
bool pass = ck::profiler::profile_grouped_gemm_two_stage_impl<ADataType,
BDataType,
EDataType,
float,
ALayout,
BLayout,
ELayout>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
EXPECT_TRUE(pass); EXPECT_TRUE(pass);
} }
}; };
...@@ -263,7 +264,7 @@ struct DeviceGroupedGemmSplitkInstanceWrapper ...@@ -263,7 +264,7 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1) if(kbatch > 1)
{ {
ggemm_instance.SetKBatchSize(argument, kbatch); ggemm_instance.SetKBatchSize(&argument, kbatch);
} }
return ggemm_instance.IsSupportedArgument(argument); return ggemm_instance.IsSupportedArgument(argument);
...@@ -300,13 +301,13 @@ struct DeviceGroupedGemmSplitkInstanceWrapper ...@@ -300,13 +301,13 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1) if(kbatch > 1)
{ {
ggemm_instance.SetKBatchSize(argument, kbatch); ggemm_instance.SetKBatchSize(&argument, kbatch);
} }
EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument)); EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
auto invoker = ggemm_instance.MakeInvoker(); auto invoker = ggemm_instance.MakeInvoker();
DeviceMem gemm_desc_workspace(ggemm_instance.GetWorkSpaceSize(&argument)); DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument));
ggemm_instance.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer()); ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer());
return invoker.Run(argument, StreamConfig{nullptr, false}); return invoker.Run(argument, StreamConfig{nullptr, false});
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment