Commit 9eed0992 authored by Adam Osewski
Browse files

Unit tests for multiple KBatch values.

parent 045bf6b6
...@@ -14,13 +14,21 @@ using F16 = ck::half_t; ...@@ -14,13 +14,21 @@ using F16 = ck::half_t;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
// using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>; using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RCR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>; using RCR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
// const std::vector<int> KBATCH{1, 2, 8, 32}; using RRR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
const std::vector<int> KBATCH{4}; using RCR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
// INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH)); const std::vector<int> KBATCH{1, 2, 3, 5, 8};
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_NK, RCR_F16_F16_F16, testing::ValuesIn(KBATCH)); INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_NK, RCR_F16_F16_F16, testing::ValuesIn(KBATCH));
// Instantiates the RRR_F16_F16_F16_LargeK suite once per parameter value (32 and 64).
// The tests in that suite hand the parameter to Run() as its final argument
// (presumably the k-batch count -- confirm against the fixture).
INSTANTIATE_TEST_SUITE_P(
    TestGroupedGemm_splitk_LargeK_MK_KN, RRR_F16_F16_F16_LargeK, testing::Values(32, 64));
// Instantiates the RCR_F16_F16_F16_LargeK suite once per parameter value (32 and 64).
// The tests in that suite hand the parameter to Run() as its final argument
// (presumably the k-batch count -- confirm against the fixture).
INSTANTIATE_TEST_SUITE_P(
    TestGroupedGemm_splitk_LargeK_MK_NK, RCR_F16_F16_F16_LargeK, testing::Values(32, 64));
#include "test_grouped_gemm_ut_cases.inc" #include "test_grouped_gemm_ut_cases.inc"
#pragma once #pragma once
// TEST_P(RRR_F16_F16_F16, TinyCases) TEST_P(RRR_F16_F16_F16, TinyCases)
// { {
// const std::vector<int> Ms{0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}; const std::vector<int> Ms{0, 1};
// const std::vector<int> Ns(Ms.size(), 4608); const int N = 768;
// const std::vector<int> Ks(Ms.size(), 384); const int K = 544;
// const std::vector<int> StrideAs(Ms.size(), 384);
// const std::vector<int> StrideBs(Ms.size(), 4608); const std::vector<int> Ns(Ms.size(), N);
// const std::vector<int> StrideCs(Ms.size(), 4608); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); const std::vector<int> StrideBs(Ms.size(), N);
// } const std::vector<int> StrideCs(Ms.size(), N);
// TEST_P(RRR_F16_F16_F16, SmallCases) this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// { }
// const std::vector<int> Ms{2, 1, 1, 1, 1, 1, 3, 4, 3, 5, 2, 4, 2, 1, 0, 1};
// const std::vector<int> Ns(Ms.size(), 4608); TEST_P(RRR_F16_F16_F16, SmallCases)
// const std::vector<int> Ks(Ms.size(), 384); {
// const std::vector<int> StrideAs(Ms.size(), 384); const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
// const std::vector<int> StrideBs(Ms.size(), 4608); const int N = 768;
// const std::vector<int> StrideCs(Ms.size(), 4608); const int K = 544;
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); const std::vector<int> Ns(Ms.size(), N);
// } const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
// TEST_P(RRR_F16_F16_F16, MidCases) const std::vector<int> StrideBs(Ms.size(), N);
// { const std::vector<int> StrideCs(Ms.size(), N);
// const std::vector<int> Ms{
// 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// const std::vector<int> Ns(Ms.size(), 4608); }
// const std::vector<int> Ks(Ms.size(), 384);
// const std::vector<int> StrideAs(Ms.size(), 384); TEST_P(RRR_F16_F16_F16, MidCases)
// const std::vector<int> StrideBs(Ms.size(), 4608); {
// const std::vector<int> StrideCs(Ms.size(), 4608); const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
const int N = 768;
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); const int K = 544;
// }
const std::vector<int> Ns(Ms.size(), N);
// TEST_P(RCR_F16_F16_F16, TinyCases) const std::vector<int> Ks(Ms.size(), K);
// { const std::vector<int> StrideAs(Ms.size(), K);
// const std::vector<int> Ms{0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}; const std::vector<int> StrideBs(Ms.size(), N);
// const std::vector<int> Ns(Ms.size(), 768); const std::vector<int> StrideCs(Ms.size(), N);
// const std::vector<int> Ks(Ms.size(), 4608);
// const std::vector<int> StrideAs(Ms.size(), 4608); this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// const std::vector<int> StrideBs(Ms.size(), 4608); }
// const std::vector<int> StrideCs(Ms.size(), 768);
TEST_P(RRR_F16_F16_F16, Regular)
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); {
// } const std::vector<int> Ms{64, 128, 256};
const int N = 768;
// TEST_P(RCR_F16_F16_F16, SmallCases) const int K = 320;
// {
// const std::vector<int> Ms{2, 1, 1, 1, 1, 1, 3, 4, 3, 5, 2, 4, 2, 1, 0, 1}; const std::vector<int> Ns(Ms.size(), N);
// const std::vector<int> Ns(Ms.size(), 768); const std::vector<int> Ks(Ms.size(), K);
// const std::vector<int> Ks(Ms.size(), 4608); const std::vector<int> StrideAs(Ms.size(), K);
// const std::vector<int> StrideAs(Ms.size(), 4608); const std::vector<int> StrideBs(Ms.size(), N);
// const std::vector<int> StrideBs(Ms.size(), 4608); const std::vector<int> StrideCs(Ms.size(), N);
// const std::vector<int> StrideCs(Ms.size(), 768);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); }
// }
// Runs four groups with M = 127, 150, 188, 210 and a shared N = 136, K = 280.
// A strides equal K; B and C strides equal N.
// NOTE(review): the test name suggests these sizes are meant to force M/N/K padding --
// confirm against the kernel tile sizes.
TEST_P(RRR_F16_F16_F16, MNKPadded)
{
    constexpr int n_dim = 136;
    constexpr int k_dim = 280;

    const std::vector<int> m_dims{127, 150, 188, 210};

    // Every per-group vector holds one repeated value, sized to the number of groups.
    const auto broadcast = [&m_dims](int value) {
        return std::vector<int>(m_dims.size(), value);
    };

    const auto n_dims    = broadcast(n_dim);
    const auto k_dims    = broadcast(k_dim);
    const auto a_strides = broadcast(k_dim);
    const auto b_strides = broadcast(n_dim);
    const auto c_strides = broadcast(n_dim);

    // The suite parameter is passed through as Run's last argument.
    this->Run(m_dims, n_dims, k_dims, a_strides, b_strides, c_strides, this->GetParam());
}
// Runs two groups with M = 0 and M = 1 and a shared N = 768, K = 544.
// A and B strides equal K; C strides equal N.
TEST_P(RCR_F16_F16_F16, TinyCases)
{
    constexpr int n_dim = 768;
    constexpr int k_dim = 544;

    // Includes a group with M == 0 alongside a single-row group.
    const std::vector<int> m_dims{0, 1};

    // Every per-group vector holds one repeated value, sized to the number of groups.
    const auto broadcast = [&m_dims](int value) {
        return std::vector<int>(m_dims.size(), value);
    };

    const auto n_dims    = broadcast(n_dim);
    const auto k_dims    = broadcast(k_dim);
    const auto a_strides = broadcast(k_dim);
    const auto b_strides = broadcast(k_dim);
    const auto c_strides = broadcast(n_dim);

    // The suite parameter is passed through as Run's last argument.
    this->Run(m_dims, n_dims, k_dims, a_strides, b_strides, c_strides, this->GetParam());
}
// Runs six groups with M = 2, 1, 3, 4, 5, 0 and a shared N = 768, K = 544.
// A and B strides equal K; C strides equal N.
TEST_P(RCR_F16_F16_F16, SmallCases)
{
    constexpr int n_dim = 768;
    constexpr int k_dim = 544;

    // The last group has M == 0.
    const std::vector<int> m_dims{2, 1, 3, 4, 5, 0};

    // Every per-group vector holds one repeated value, sized to the number of groups.
    const auto broadcast = [&m_dims](int value) {
        return std::vector<int>(m_dims.size(), value);
    };

    const auto n_dims    = broadcast(n_dim);
    const auto k_dims    = broadcast(k_dim);
    const auto a_strides = broadcast(k_dim);
    const auto b_strides = broadcast(k_dim);
    const auto c_strides = broadcast(n_dim);

    // The suite parameter is passed through as Run's last argument.
    this->Run(m_dims, n_dims, k_dims, a_strides, b_strides, c_strides, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, MidCases) TEST_P(RCR_F16_F16_F16, MidCases)
{ {
const std::vector<int> Ms{ const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
// 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; const int N = 768;
167}; const int K = 544;
// const std::vector<int> Ns(Ms.size(), 768);
// const std::vector<int> Ks(Ms.size(), 4608); const std::vector<int> Ns(Ms.size(), N);
// const std::vector<int> StrideAs(Ms.size(), 4608); const std::vector<int> Ks(Ms.size(), K);
// const std::vector<int> StrideBs(Ms.size(), 4608); const std::vector<int> StrideAs(Ms.size(), K);
// const std::vector<int> StrideCs(Ms.size(), 768); const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
const std::vector<int> Ns(Ms.size(), 256);
const std::vector<int> Ks(Ms.size(), 128); this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
const std::vector<int> StrideAs(Ms.size(), 128); }
const std::vector<int> StrideBs(Ms.size(), 128);
const std::vector<int> StrideCs(Ms.size(), 256); TEST_P(RCR_F16_F16_F16, Regular)
{
const std::vector<int> Ms{32, 64, 128, 256};
const int N = 768;
const int K = 320;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
// Runs four groups with M = 127, 150, 188, 210 and a shared N = 136, K = 280.
// A and B strides equal K; C strides equal N.
// NOTE(review): the test name suggests these sizes are meant to force M/N/K padding --
// confirm against the kernel tile sizes.
TEST_P(RCR_F16_F16_F16, MNKPadded)
{
    constexpr int n_dim = 136;
    constexpr int k_dim = 280;

    const std::vector<int> m_dims{127, 150, 188, 210};

    // Every per-group vector holds one repeated value, sized to the number of groups.
    const auto broadcast = [&m_dims](int value) {
        return std::vector<int>(m_dims.size(), value);
    };

    const auto n_dims    = broadcast(n_dim);
    const auto k_dims    = broadcast(k_dim);
    const auto a_strides = broadcast(k_dim);
    const auto b_strides = broadcast(k_dim);
    const auto c_strides = broadcast(n_dim);

    // The suite parameter is passed through as Run's last argument.
    this->Run(m_dims, n_dims, k_dims, a_strides, b_strides, c_strides, this->GetParam());
}
// Runs two groups with M = 188 and M = 210 and a shared N = 768, K = 4096.
// A strides equal K; B and C strides equal N.
TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
{
    constexpr int n_dim = 768;
    constexpr int k_dim = 4096;

    const std::vector<int> m_dims{188, 210};

    // Every per-group vector holds one repeated value, sized to the number of groups.
    const auto broadcast = [&m_dims](int value) {
        return std::vector<int>(m_dims.size(), value);
    };

    const auto n_dims    = broadcast(n_dim);
    const auto k_dims    = broadcast(k_dim);
    const auto a_strides = broadcast(k_dim);
    const auto b_strides = broadcast(n_dim);
    const auto c_strides = broadcast(n_dim);

    // The suite parameter is passed through as Run's last argument.
    this->Run(m_dims, n_dims, k_dims, a_strides, b_strides, c_strides, this->GetParam());
}
TEST_P(RCR_F16_F16_F16_LargeK, TestLargeKBatch)
{
const std::vector<int> Ms{188, 210};
const int N = 768;
const int K = 4096;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
} }
...@@ -56,15 +56,6 @@ class TestGroupedGemm : public testing::TestWithParam<int> ...@@ -56,15 +56,6 @@ class TestGroupedGemm : public testing::TestWithParam<int>
const std::vector<int>& StrideCs, const std::vector<int>& StrideCs,
int kbatch = 1) int kbatch = 1)
{ {
std::cout << "Ms: [" << serialize_range(Ms) << "] "
<< "Ns: [" << serialize_range(Ns) << "] "
<< "Ks: [" << serialize_range(Ks) << "] "
<< "StrideAs: [" << serialize_range(StrideAs) << "] "
<< "StrideBs: [" << serialize_range(StrideBs) << "] "
<< "StrideCs: [" << serialize_range(StrideCs) << "] "
<< "kbatch: " << kbatch << std::endl;
bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType, bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType, BDataType,
EDataType, EDataType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment