Commit 406f71a6 authored by Adam Osewski's avatar Adam Osewski
Browse files

Debug: limit number of test cases to run;

parent c2bd0148
......@@ -14,12 +14,13 @@ using F16 = ck::half_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
// using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RCR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
const std::vector<int> KBATCH{1, 2, 4, 6, 8, 10, 12, 14, 16, 32, 64, 128};
// const std::vector<int> KBATCH{1, 2, 8, 32};
const std::vector<int> KBATCH{4};
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH));
// INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_NK, RCR_F16_F16_F16, testing::ValuesIn(KBATCH));
#include "test_grouped_gemm_ut_cases.inc"
#pragma once
TEST_P(RCR_F16_F16_F16, TinyCases)
{
const std::vector<int> Ms{0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
const std::vector<int> Ns(Ms.size(), 768);
const std::vector<int> Ks(Ms.size(), 4068);
const std::vector<int> StrideAs(Ms.size(), 4608);
const std::vector<int> StrideBs(Ms.size(), 4608);
const std::vector<int> StrideCs(Ms.size(), 768);
// TEST_P(RRR_F16_F16_F16, TinyCases)
// {
// const std::vector<int> Ms{0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
// const std::vector<int> Ns(Ms.size(), 4608);
// const std::vector<int> Ks(Ms.size(), 384);
// const std::vector<int> StrideAs(Ms.size(), 384);
// const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 4608);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// }
TEST_P(RRR_F16_F16_F16, TinyCases)
{
const std::vector<int> Ms{0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
const std::vector<int> Ns(Ms.size(), 4608);
const std::vector<int> Ks(Ms.size(), 384);
const std::vector<int> StrideAs(Ms.size(), 384);
const std::vector<int> StrideBs(Ms.size(), 4608);
const std::vector<int> StrideCs(Ms.size(), 4608);
// TEST_P(RRR_F16_F16_F16, SmallCases)
// {
// const std::vector<int> Ms{2, 1, 1, 1, 1, 1, 3, 4, 3, 5, 2, 4, 2, 1, 0, 1};
// const std::vector<int> Ns(Ms.size(), 4608);
// const std::vector<int> Ks(Ms.size(), 384);
// const std::vector<int> StrideAs(Ms.size(), 384);
// const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 4608);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// }
TEST_P(RCR_F16_F16_F16, SmallCases)
{
const std::vector<int> Ms{2, 1, 1, 1, 1, 1, 3, 4, 3, 5, 2, 4, 2, 1, 0, 1};
const std::vector<int> Ns(Ms.size(), 768);
const std::vector<int> Ks(Ms.size(), 4068);
const std::vector<int> StrideAs(Ms.size(), 4608);
const std::vector<int> StrideBs(Ms.size(), 4608);
const std::vector<int> StrideCs(Ms.size(), 768);
// TEST_P(RRR_F16_F16_F16, MidCases)
// {
// const std::vector<int> Ms{
// 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
// const std::vector<int> Ns(Ms.size(), 4608);
// const std::vector<int> Ks(Ms.size(), 384);
// const std::vector<int> StrideAs(Ms.size(), 384);
// const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 4608);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// }
TEST_P(RRR_F16_F16_F16, SmallCases)
{
const std::vector<int> Ms{2, 1, 1, 1, 1, 1, 3, 4, 3, 5, 2, 4, 2, 1, 0, 1};
const std::vector<int> Ns(Ms.size(), 4608);
const std::vector<int> Ks(Ms.size(), 384);
const std::vector<int> StrideAs(Ms.size(), 384);
const std::vector<int> StrideBs(Ms.size(), 4608);
const std::vector<int> StrideCs(Ms.size(), 4608);
// TEST_P(RCR_F16_F16_F16, TinyCases)
// {
// const std::vector<int> Ms{0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
// const std::vector<int> Ns(Ms.size(), 768);
// const std::vector<int> Ks(Ms.size(), 4608);
// const std::vector<int> StrideAs(Ms.size(), 4608);
// const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 768);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// }
TEST_P(RCR_F16_F16_F16, MidCases)
{
const std::vector<int> Ms{
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
const std::vector<int> Ns(Ms.size(), 768);
const std::vector<int> Ks(Ms.size(), 4068);
const std::vector<int> StrideAs(Ms.size(), 4608);
const std::vector<int> StrideBs(Ms.size(), 4608);
const std::vector<int> StrideCs(Ms.size(), 768);
// TEST_P(RCR_F16_F16_F16, SmallCases)
// {
// const std::vector<int> Ms{2, 1, 1, 1, 1, 1, 3, 4, 3, 5, 2, 4, 2, 1, 0, 1};
// const std::vector<int> Ns(Ms.size(), 768);
// const std::vector<int> Ks(Ms.size(), 4608);
// const std::vector<int> StrideAs(Ms.size(), 4608);
// const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 768);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// }
TEST_P(RRR_F16_F16_F16, MidCases)
TEST_P(RCR_F16_F16_F16, MidCases)
{
const std::vector<int> Ms{
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
const std::vector<int> Ns(Ms.size(), 4608);
const std::vector<int> Ks(Ms.size(), 384);
const std::vector<int> StrideAs(Ms.size(), 384);
const std::vector<int> StrideBs(Ms.size(), 4608);
const std::vector<int> StrideCs(Ms.size(), 4608);
// 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
167};
// const std::vector<int> Ns(Ms.size(), 768);
// const std::vector<int> Ks(Ms.size(), 4608);
// const std::vector<int> StrideAs(Ms.size(), 4608);
// const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 768);
const std::vector<int> Ns(Ms.size(), 256);
const std::vector<int> Ks(Ms.size(), 128);
const std::vector<int> StrideAs(Ms.size(), 128);
const std::vector<int> StrideBs(Ms.size(), 128);
const std::vector<int> StrideCs(Ms.size(), 256);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment