Commit d22c892c authored by Adam Osewski

Unit tests for MKNK ggemm interface.

parent c67ae342
CMakeLists.txt

add_custom_target(test_grouped_gemm)
add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp)
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface.cpp)
target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance)
target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance)
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk test_grouped_gemm_interface)
test_grouped_gemm_interface.cpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <stdexcept>
#include <vector>
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "test_grouped_gemm_util.hpp"
class TestGGemmSplitKInterface_MKNK : public ::testing::Test
{
    protected:
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    using ALayout = Row;
    using BLayout = Col;
    using ELayout = Row;

    static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

    template <ck::index_t KPerBlock,
              ck::index_t K1,
              ck::index_t ABlockTransferSrcScalarPerVector,
              ck::index_t BBlockTransferSrcScalarPerVector,
              ck::index_t CDEBlockTransferScalarPerVector_NPerBlock>
    using GGemmInstance =
        ck::test::DeviceGroupedGemmSplitkInstanceWrapper<ALayout,
                                                         BLayout,
                                                         ELayout,
                                                         GemmDefault,
                                                         KPerBlock,
                                                         K1,
                                                         ABlockTransferSrcScalarPerVector,
                                                         BBlockTransferSrcScalarPerVector,
                                                         CDEBlockTransferScalarPerVector_NPerBlock>;

    // KPerBlock = 32, K1 = 8, vector widths: A = 4, B = 8, CDE = 8.
    using DefaultGGemmInstance = GGemmInstance<32, 8, 4, 8, 8>;
};
TEST_F(TestGGemmSplitKInterface_MKNK, TileSize)
{
    std::vector<int> Ms{128, 256, 188, 512};
    int N = 256;
    int K = 128;

    std::vector<int> Ns(Ms.size(), N);
    std::vector<int> Ks(Ms.size(), K);
    std::vector<int> StrideAs(Ms.size(), K);
    std::vector<int> StrideBs(Ms.size(), K);
    std::vector<int> StrideCs(Ms.size(), N);

    // M % MPerBlock != 0 (M = 188, MPerBlock = 128)
    EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));

    Ns = std::vector<int>{256, 177, 128, 512};
    // N % NPerBlock != 0 (N = 177, NPerBlock = 128)
    EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
}
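
// An illustrative sketch (not part of this commit): the complementary positive
// case. Every M and N below is a multiple of the default tile sizes
// (MPerBlock = NPerBlock = 128) and K satisfies the vector-width constraints,
// so IsSupported() is expected to return true.
TEST_F(TestGGemmSplitKInterface_MKNK, SupportedTileSize)
{
    std::vector<int> Ms{128, 256, 512};
    int N = 256;
    int K = 128;

    std::vector<int> Ns(Ms.size(), N);
    std::vector<int> Ks(Ms.size(), K);
    std::vector<int> StrideAs(Ms.size(), K);
    std::vector<int> StrideBs(Ms.size(), K);
    std::vector<int> StrideCs(Ms.size(), N);

    EXPECT_TRUE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
}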
TEST_F(TestGGemmSplitKInterface_MKNK, VectorLoadWidth)
{
    std::vector<int> Ms{128, 256, 256, 512};
    int N = 256;
    int K = 512;

    std::vector<int> Ns(Ms.size(), N);
    std::vector<int> Ks(Ms.size(), K);
    std::vector<int> StrideAs(Ms.size(), K);
    std::vector<int> StrideBs(Ms.size(), K);
    std::vector<int> StrideCs(Ms.size(), N);

    // K % ABlockTransferSrcScalarPerVector != 0 (K = 177, vector width = 4)
    Ks = std::vector<int>{256, 177, 128, 512};
    EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));

    // K % BBlockTransferSrcScalarPerVector != 0 (K = 164, vector width = 8)
    Ks = std::vector<int>{256, 164, 128, 512};
    EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));

    // N % CDEBlockTransferScalarPerVector_NPerBlock != 0 (N = 153, vector width = 8)
    Ks = std::vector<int>(4, 128);
    Ns = std::vector<int>{256, 153, 128, 512};
    EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
}
TEST_F(TestGGemmSplitKInterface_MKNK, KLoops)
{
    std::vector<int> Ms{128, 256, 256, 512};
    int N = 256;
    int K = 128;
    int kbatch = 4;

    std::vector<int> Ns(Ms.size(), N);
    std::vector<int> Ks(Ms.size(), K);
    std::vector<int> StrideAs(Ms.size(), K);
    std::vector<int> StrideBs(Ms.size(), K);
    std::vector<int> StrideCs(Ms.size(), N);

    // kloops % 2 != 0
    Ks = std::vector<int>{256, 512, 320, 768};
    EXPECT_FALSE(
        DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));

    // Not all GEMMs have the same value of main_k0_block_loop!
    Ks = std::vector<int>{256, 512, 512, 512};
    EXPECT_THROW(DefaultGGemmInstance{}.Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch),
                 std::runtime_error);
}
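
// An illustrative sketch (not part of this commit): with a uniform K = 512 and
// kbatch = 4, every group computes the same, even number of K loops, so no
// group can disagree on main_k0_block_loop and the split-K argument is
// expected to be supported.
TEST_F(TestGGemmSplitKInterface_MKNK, UniformKLoops)
{
    std::vector<int> Ms{128, 256, 256, 512};
    int N = 256;
    int K = 512;
    int kbatch = 4;

    std::vector<int> Ns(Ms.size(), N);
    std::vector<int> Ks(Ms.size(), K);
    std::vector<int> StrideAs(Ms.size(), K);
    std::vector<int> StrideBs(Ms.size(), K);
    std::vector<int> StrideCs(Ms.size(), N);

    EXPECT_TRUE(
        DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
}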

test_grouped_gemm_util.hpp

@@ -3,6 +3,7 @@
#pragma once

#include <array>
#include <string>
#include <sstream>
#include <tuple>
@@ -10,8 +11,16 @@
#include <gtest/gtest.h>

#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/number.hpp"

#include "profiler/profile_grouped_gemm_impl.hpp"

namespace ck {

@@ -68,5 +77,173 @@ class TestGroupedGemm : public testing::TestWithParam<int>
    }
};

template <typename ALayout,
          typename BLayout,
          typename ELayout,
          tensor_operation::device::GemmSpecialization GemmSpec,
          ck::index_t KPerBlock,
          ck::index_t K1,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t CDEBlockTransferScalarPerVector_NPerBlock>
struct DeviceGroupedGemmSplitkInstanceWrapper
{
    using F16 = half_t;
    using F32 = float;

    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    using PassThrough = tensor_operation::element_wise::PassThrough;
    using EmptyTuple  = ck::Tuple<>;

    template <ck::index_t... Is>
    using S = ck::Sequence<Is...>;
    template <ck::index_t N>
    using I = ck::Number<N>;
    // A-matrix block-transfer parameters, selected by ALayout.
    using ABlockTransferThreadClusterArrangeOrder =
        std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
    using ABlockTransferSrcAccessOrder =
        std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
    using ABlockTransferSrcVectorDim =
        std::conditional_t<std::is_same_v<ALayout, Row>, I<3>, I<2>>;
    using ABlockTransferDstScalarPerVector_K1 =
        std::conditional_t<std::is_same_v<ALayout, Row>, I<8>, I<2>>;
    using ABlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<1>, I<0>>;

    // B-matrix block-transfer parameters, selected by BLayout.
    using BBlockTransferThreadClusterArrangeOrder =
        std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
    using BBlockTransferSrcAccessOrder =
        std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
    using BBlockTransferSrcVectorDim =
        std::conditional_t<std::is_same_v<BLayout, Row>, I<2>, I<3>>;
    using BBlockTransferDstScalarPerVector_K1 =
        std::conditional_t<std::is_same_v<BLayout, Row>, I<2>, I<8>>;
    using BBlockLdsAddExtraM = std::conditional_t<std::is_same_v<BLayout, Row>, I<0>, I<1>>;
    using DeviceGroupedGemmSplitKInstance =
        tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle<
            ALayout,
            BLayout,
            EmptyTuple,  // DsLayout
            ELayout,
            F16,         // ADataType
            F16,         // BDataType
            F32,         // AccDataType
            F16,         // CShuffleDataType
            EmptyTuple,  // DsDataType
            F16,         // EDataType
            PassThrough, // AElementwiseOperation
            PassThrough, // BElementwiseOperation
            PassThrough, // CDEElementwiseOperation
            GemmSpec,
            1,   // NumGemmKPrefetchStage
            128, // BlockSize
            128, // MPerBlock
            128, // NPerBlock
            KPerBlock,
            K1,  // AK1
            K1,  // BK1
            32,  // MPerXDL
            32,  // NPerXDL
            4,   // MXdlPerWave
            2,   // NXdlPerWave
            S<1, 4, 32, 1>, // ABlockTransferThreadClusterLengths
            ABlockTransferThreadClusterArrangeOrder,
            ABlockTransferSrcAccessOrder,
            ABlockTransferSrcVectorDim::value,
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_K1::value,
            ABlockLdsAddExtraM::value,
            S<1, 4, 32, 1>, // BBlockTransferThreadClusterLengths
            BBlockTransferThreadClusterArrangeOrder,
            BBlockTransferSrcAccessOrder,
            BBlockTransferSrcVectorDim::value,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_K1::value,
            BBlockLdsAddExtraM::value,
            1, // CShuffleMXdlPerWavePerShuffle
            1, // CShuffleNXdlPerWavePerShuffle
            S<1, 16, 1, 8>, // CDEBlockTransferClusterLengths
            CDEBlockTransferScalarPerVector_NPerBlock>;
    bool IsSupported(const std::vector<int>& Ms,
                     const std::vector<int>& Ns,
                     const std::vector<int>& Ks,
                     const std::vector<int>& StrideAs,
                     const std::vector<int>& StrideBs,
                     const std::vector<int>& StrideCs,
                     int kbatch = 1) const
    {
        std::size_t n_groups = Ms.size();
        EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups &&
                    StrideAs.size() == n_groups && StrideBs.size() == n_groups &&
                    StrideCs.size() == n_groups)
            << "The number of groups is not consistent!";

        std::vector<tensor_operation::device::GemmDesc> gemm_descs;
        for(std::size_t i = 0; i < n_groups; ++i)
        {
            gemm_descs.push_back(tensor_operation::device::GemmDesc{
                Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
        }

        // Only argument metadata is inspected here, so null data pointers suffice.
        std::vector<const void*> p_As(n_groups, nullptr);
        std::vector<const void*> p_Bs(n_groups, nullptr);
        std::vector<void*> p_Cs(n_groups, nullptr);
        auto p_Ds = std::vector<std::array<const void*, 0>>{};

        auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
        auto argument       = ggemm_instance.MakeArgument(
            p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});

        if(kbatch > 1)
        {
            ggemm_instance.SetKBatchSize(argument, kbatch);
        }

        return ggemm_instance.IsSupportedArgument(argument);
    }
    float Run(const std::vector<int>& Ms,
              const std::vector<int>& Ns,
              const std::vector<int>& Ks,
              const std::vector<int>& StrideAs,
              const std::vector<int>& StrideBs,
              const std::vector<int>& StrideCs,
              int kbatch = 1) const
    {
        std::size_t n_groups = Ms.size();
        EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups &&
                    StrideAs.size() == n_groups && StrideBs.size() == n_groups &&
                    StrideCs.size() == n_groups)
            << "The number of groups is not consistent!";

        std::vector<tensor_operation::device::GemmDesc> gemm_descs;
        for(std::size_t i = 0; i < n_groups; ++i)
        {
            gemm_descs.push_back(tensor_operation::device::GemmDesc{
                Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
        }

        std::vector<const void*> p_As(n_groups, nullptr);
        std::vector<const void*> p_Bs(n_groups, nullptr);
        std::vector<void*> p_Cs(n_groups, nullptr);
        auto p_Ds = std::vector<std::array<const void*, 0>>{};

        auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
        auto argument       = ggemm_instance.MakeArgument(
            p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});

        if(kbatch > 1)
        {
            ggemm_instance.SetKBatchSize(argument, kbatch);
        }

        EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));

        auto invoker = ggemm_instance.MakeInvoker();
        DeviceMem gemm_desc_workspace(ggemm_instance.GetWorkSpaceSize(&argument));
        ggemm_instance.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());

        return invoker.Run(argument, StreamConfig{nullptr, false});
    }
};
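
// Usage sketch (illustrative only, not from this commit): the wrapper is
// layout-generic, so other layout combinations can be declared the same way.
// The alias below is hypothetical and its vector widths are placeholders that
// would have to match the chosen layouts' access patterns.
using HypotheticalMKKNWrapper =
    DeviceGroupedGemmSplitkInstanceWrapper<ck::tensor_layout::gemm::RowMajor, // ALayout
                                           ck::tensor_layout::gemm::RowMajor, // BLayout
                                           ck::tensor_layout::gemm::RowMajor, // ELayout
                                           tensor_operation::device::GemmSpecialization::Default,
                                           32, // KPerBlock
                                           8,  // K1
                                           8,  // ABlockTransferSrcScalarPerVector
                                           8,  // BBlockTransferSrcScalarPerVector
                                           8>; // CDEBlockTransferScalarPerVector_NPerBlock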
} // namespace test
} // namespace ck