Commit 406f71a6 authored by Adam Osewski's avatar Adam Osewski
Browse files

Debug: limit number of test cases to run;

parent c2bd0148
...@@ -14,12 +14,13 @@ using F16 = ck::half_t; ...@@ -14,12 +14,13 @@ using F16 = ck::half_t;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>; // using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RCR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>; using RCR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
const std::vector<int> KBATCH{1, 2, 4, 6, 8, 10, 12, 14, 16, 32, 64, 128}; // const std::vector<int> KBATCH{1, 2, 8, 32};
const std::vector<int> KBATCH{4};
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH)); // INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_NK, RCR_F16_F16_F16, testing::ValuesIn(KBATCH)); INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_NK, RCR_F16_F16_F16, testing::ValuesIn(KBATCH));
#include "test_grouped_gemm_ut_cases.inc" #include "test_grouped_gemm_ut_cases.inc"
#pragma once #pragma once
TEST_P(RCR_F16_F16_F16, TinyCases) // TEST_P(RRR_F16_F16_F16, TinyCases)
{ // {
const std::vector<int> Ms{0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}; // const std::vector<int> Ms{0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
const std::vector<int> Ns(Ms.size(), 768); // const std::vector<int> Ns(Ms.size(), 4608);
const std::vector<int> Ks(Ms.size(), 4068); // const std::vector<int> Ks(Ms.size(), 384);
const std::vector<int> StrideAs(Ms.size(), 4608); // const std::vector<int> StrideAs(Ms.size(), 384);
const std::vector<int> StrideBs(Ms.size(), 4608); // const std::vector<int> StrideBs(Ms.size(), 4608);
const std::vector<int> StrideCs(Ms.size(), 768); // const std::vector<int> StrideCs(Ms.size(), 4608);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); // this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
} // }
TEST_P(RRR_F16_F16_F16, TinyCases) // TEST_P(RRR_F16_F16_F16, SmallCases)
{ // {
const std::vector<int> Ms{0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}; // const std::vector<int> Ms{2, 1, 1, 1, 1, 1, 3, 4, 3, 5, 2, 4, 2, 1, 0, 1};
const std::vector<int> Ns(Ms.size(), 4608); // const std::vector<int> Ns(Ms.size(), 4608);
const std::vector<int> Ks(Ms.size(), 384); // const std::vector<int> Ks(Ms.size(), 384);
const std::vector<int> StrideAs(Ms.size(), 384); // const std::vector<int> StrideAs(Ms.size(), 384);
const std::vector<int> StrideBs(Ms.size(), 4608); // const std::vector<int> StrideBs(Ms.size(), 4608);
const std::vector<int> StrideCs(Ms.size(), 4608); // const std::vector<int> StrideCs(Ms.size(), 4608);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); // this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
} // }
TEST_P(RCR_F16_F16_F16, SmallCases) // TEST_P(RRR_F16_F16_F16, MidCases)
{ // {
const std::vector<int> Ms{2, 1, 1, 1, 1, 1, 3, 4, 3, 5, 2, 4, 2, 1, 0, 1}; // const std::vector<int> Ms{
const std::vector<int> Ns(Ms.size(), 768); // 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
const std::vector<int> Ks(Ms.size(), 4068); // const std::vector<int> Ns(Ms.size(), 4608);
const std::vector<int> StrideAs(Ms.size(), 4608); // const std::vector<int> Ks(Ms.size(), 384);
const std::vector<int> StrideBs(Ms.size(), 4608); // const std::vector<int> StrideAs(Ms.size(), 384);
const std::vector<int> StrideCs(Ms.size(), 768); // const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 4608);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); // this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
} // }
TEST_P(RRR_F16_F16_F16, SmallCases) // TEST_P(RCR_F16_F16_F16, TinyCases)
{ // {
const std::vector<int> Ms{2, 1, 1, 1, 1, 1, 3, 4, 3, 5, 2, 4, 2, 1, 0, 1}; // const std::vector<int> Ms{0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
const std::vector<int> Ns(Ms.size(), 4608); // const std::vector<int> Ns(Ms.size(), 768);
const std::vector<int> Ks(Ms.size(), 384); // const std::vector<int> Ks(Ms.size(), 4608);
const std::vector<int> StrideAs(Ms.size(), 384); // const std::vector<int> StrideAs(Ms.size(), 4608);
const std::vector<int> StrideBs(Ms.size(), 4608); // const std::vector<int> StrideBs(Ms.size(), 4608);
const std::vector<int> StrideCs(Ms.size(), 4608); // const std::vector<int> StrideCs(Ms.size(), 768);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); // this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
} // }
TEST_P(RCR_F16_F16_F16, MidCases) // TEST_P(RCR_F16_F16_F16, SmallCases)
{ // {
const std::vector<int> Ms{ // const std::vector<int> Ms{2, 1, 1, 1, 1, 1, 3, 4, 3, 5, 2, 4, 2, 1, 0, 1};
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; // const std::vector<int> Ns(Ms.size(), 768);
const std::vector<int> Ns(Ms.size(), 768); // const std::vector<int> Ks(Ms.size(), 4608);
const std::vector<int> Ks(Ms.size(), 4068); // const std::vector<int> StrideAs(Ms.size(), 4608);
const std::vector<int> StrideAs(Ms.size(), 4608); // const std::vector<int> StrideBs(Ms.size(), 4608);
const std::vector<int> StrideBs(Ms.size(), 4608); // const std::vector<int> StrideCs(Ms.size(), 768);
const std::vector<int> StrideCs(Ms.size(), 768);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); // this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
} // }
TEST_P(RRR_F16_F16_F16, MidCases) TEST_P(RCR_F16_F16_F16, MidCases)
{ {
const std::vector<int> Ms{ const std::vector<int> Ms{
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; // 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
const std::vector<int> Ns(Ms.size(), 4608); 167};
const std::vector<int> Ks(Ms.size(), 384); // const std::vector<int> Ns(Ms.size(), 768);
const std::vector<int> StrideAs(Ms.size(), 384); // const std::vector<int> Ks(Ms.size(), 4608);
const std::vector<int> StrideBs(Ms.size(), 4608); // const std::vector<int> StrideAs(Ms.size(), 4608);
const std::vector<int> StrideCs(Ms.size(), 4608); // const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 768);
const std::vector<int> Ns(Ms.size(), 256);
const std::vector<int> Ks(Ms.size(), 128);
const std::vector<int> StrideAs(Ms.size(), 128);
const std::vector<int> StrideBs(Ms.size(), 128);
const std::vector<int> StrideCs(Ms.size(), 256);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment