Commit 01f0831b authored by Adam Osewski's avatar Adam Osewski
Browse files

Fix and add more unit-tests.

parent 1e188f76
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "test_grouped_gemm_util.hpp" #include "test_grouped_gemm_util.hpp"
class TestGGemmSplitKInterface_MKNK : public ::testing::Test class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test
{ {
protected: protected:
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
...@@ -22,7 +22,8 @@ class TestGGemmSplitKInterface_MKNK : public ::testing::Test ...@@ -22,7 +22,8 @@ class TestGGemmSplitKInterface_MKNK : public ::testing::Test
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
template <ck::index_t KPerBlock, template <ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t KPerBlock,
ck::index_t K1, ck::index_t K1,
ck::index_t ABlockTransferSrcScalarPerVector, ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferSrcScalarPerVector, ck::index_t BBlockTransferSrcScalarPerVector,
...@@ -31,17 +32,17 @@ class TestGGemmSplitKInterface_MKNK : public ::testing::Test ...@@ -31,17 +32,17 @@ class TestGGemmSplitKInterface_MKNK : public ::testing::Test
ck::test::DeviceGroupedGemmSplitkInstanceWrapper<ALayout, ck::test::DeviceGroupedGemmSplitkInstanceWrapper<ALayout,
BLayout, BLayout,
ELayout, ELayout,
GemmDefault, GemmSpec,
KPerBlock, KPerBlock,
K1, K1,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
CDEBlockTransferScalarPerVector_NPerBlock>; CDEBlockTransferScalarPerVector_NPerBlock>;
using DefaultGGemmInstance = GGemmInstance<32, 8, 4, 8, 8>; using DefaultGGemmInstance = GGemmInstance<GemmDefault, 32, 8, 4, 8, 8>;
}; };
TEST_F(TestGGemmSplitKInterface_MKNK, TileSize) TEST_F(TestGGemmSplitKInterface_MKNKMN, TileSize)
{ {
std::vector<int> Ms{128, 256, 188, 512}; std::vector<int> Ms{128, 256, 188, 512};
int N = 256; int N = 256;
...@@ -56,13 +57,18 @@ TEST_F(TestGGemmSplitKInterface_MKNK, TileSize) ...@@ -56,13 +57,18 @@ TEST_F(TestGGemmSplitKInterface_MKNK, TileSize)
// M % MPerBlock // M % MPerBlock
EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
Ms = std::vector<int>{256, 128, 128, 512};
Ns = std::vector<int>{256, 177, 128, 512}; Ns = std::vector<int>{256, 177, 128, 512};
// N % NPerBlock // N % NPerBlock
EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
} }
TEST_F(TestGGemmSplitKInterface_MKNK, VectorLoadWidth) TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth)
{ {
static constexpr auto GemmMNKPadding =
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 32, 8, 4, 8, 8>;
std::vector<int> Ms{128, 256, 256, 512}; std::vector<int> Ms{128, 256, 256, 512};
int N = 256; int N = 256;
int K = 512; int K = 512;
...@@ -75,19 +81,19 @@ TEST_F(TestGGemmSplitKInterface_MKNK, VectorLoadWidth) ...@@ -75,19 +81,19 @@ TEST_F(TestGGemmSplitKInterface_MKNK, VectorLoadWidth)
// K % ABlockTransferSrcScalarPerVector // K % ABlockTransferSrcScalarPerVector
Ks = std::vector<int>{256, 177, 128, 512}; Ks = std::vector<int>{256, 177, 128, 512};
EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
Ks = std::vector<int>{256, 164, 128, 512}; Ks = std::vector<int>{256, 164, 128, 512};
// K % BBlockTransferSrcScalarPerVector // K % BBlockTransferSrcScalarPerVector
EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
Ks = std::vector<int>(4, 128); Ks = std::vector<int>(4, 128);
Ns = std::vector<int>{256, 153, 128, 512}; Ns = std::vector<int>{256, 127, 128, 512};
// N % CBlockTransferScalarPerVector_NWaveNPerXDL // N % CBlockTransferScalarPerVector_NWaveNPerXDL
EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
} }
TEST_F(TestGGemmSplitKInterface_MKNK, KLoops) TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops)
{ {
std::vector<int> Ms{128, 256, 256, 512}; std::vector<int> Ms{128, 256, 256, 512};
int N = 256; int N = 256;
...@@ -110,3 +116,87 @@ TEST_F(TestGGemmSplitKInterface_MKNK, KLoops) ...@@ -110,3 +116,87 @@ TEST_F(TestGGemmSplitKInterface_MKNK, KLoops)
EXPECT_THROW(DefaultGGemmInstance{}.Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch), EXPECT_THROW(DefaultGGemmInstance{}.Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch),
std::runtime_error); std::runtime_error);
} }
// Fixture for grouped-GEMM split-K interface tests with the layout
// combination: A column-major, B row-major, E column-major
// (the _KMKNNM suffix encodes these layouts).
class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test
{
    protected:
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;
    // Layouts under test: A = Col, B = Row, E (output) = Col.
    using ALayout = Col;
    using BLayout = Row;
    using ELayout = Col;
    static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
    // Thin alias that fixes the layouts and forwards the remaining tuning
    // knobs (specialization, K tiling, and per-dimension vector widths) to
    // the test wrapper around the device grouped-GEMM split-K instance.
    template <ck::tensor_operation::device::GemmSpecialization GemmSpec,
              ck::index_t KPerBlock,
              ck::index_t K1,
              ck::index_t ABlockTransferSrcScalarPerVector,
              ck::index_t BBlockTransferSrcScalarPerVector,
              ck::index_t CDEBlockTransferScalarPerVector_NPerBlock>
    using GGemmInstance =
        ck::test::DeviceGroupedGemmSplitkInstanceWrapper<ALayout,
                                                         BLayout,
                                                         ELayout,
                                                         GemmSpec,
                                                         KPerBlock,
                                                         K1,
                                                         ABlockTransferSrcScalarPerVector,
                                                         BBlockTransferSrcScalarPerVector,
                                                         CDEBlockTransferScalarPerVector_NPerBlock>;
    // Baseline configuration used by the tests: Default specialization,
    // KPerBlock=32, K1=8, A/B/CDE vector widths = 4/8/4.
    using DefaultGGemmInstance = GGemmInstance<GemmDefault, 32, 8, 4, 8, 4>;
};
// Tile-size divisibility: with the Default specialization, any group whose
// M or N is not a multiple of the block tile must be rejected by IsSupported.
TEST_F(TestGGemmSplitKInterface_KMKNNM, TileSize)
{
    const int group_count = 4;
    const int N           = 256;
    const int K           = 128;

    // 188 is not a multiple of MPerBlock.
    std::vector<int> Ms = {128, 256, 188, 512};
    std::vector<int> Ns(group_count, N);
    std::vector<int> Ks(group_count, K);
    // NOTE(review): strides mirror the row-major variant of this test; the
    // divisibility checks below do not depend on them — confirm against the
    // wrapper if stride validation is ever added.
    std::vector<int> StrideAs(group_count, K);
    std::vector<int> StrideBs(group_count, K);
    std::vector<int> StrideCs(group_count, N);

    // M % MPerBlock
    EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));

    Ms = {128, 256, 256, 512};
    // 177 is not a multiple of NPerBlock.
    Ns = {256, 177, 128, 512};
    // N % NPerBlock
    EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
}
// Vector-load-width divisibility: even with MNK padding enabled, dimensions
// that are not multiples of the configured per-vector scalar counts must make
// IsSupported return false.
TEST_F(TestGGemmSplitKInterface_KMKNNM, VectorLoadWidth)
{
    static constexpr auto GemmMNKPadding =
        ck::tensor_operation::device::GemmSpecialization::MNKPadding;
    // Padded instance with A/B/CDE vector widths = 2/8/4.
    using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 32, 8, 2, 8, 4>;

    const int group_count = 4;
    const int N           = 256;
    const int K           = 512;

    std::vector<int> Ms = {128, 256, 256, 512};
    std::vector<int> Ns(group_count, N);
    std::vector<int> Ks(group_count, K);
    std::vector<int> StrideAs(group_count, K);
    std::vector<int> StrideBs(group_count, K);
    std::vector<int> StrideCs(group_count, N);

    // M % ABlockTransferSrcScalarPerVector (177 is odd, vector width is 2)
    Ms = {256, 177, 128, 512};
    EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));

    Ms = {128, 256, 256, 512};
    // N % BBlockTransferSrcScalarPerVector (164 % 8 != 0)
    Ns = {256, 164, 128, 512};
    EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));

    Ns = {128, 256, 256, 512};
    // M % CBlockTransferScalarPerVector_NWaveNPerXDL (130 % 4 != 0)
    Ms = {256, 130, 128, 512};
    EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment