"...composable_kernel.git" did not exist on "d140bdc9fa251d9519055c932e169e510d7f6785"
Commit 9eed0992 authored by Adam Osewski's avatar Adam Osewski
Browse files

Unit tests for multiple KBatch values.

parent 045bf6b6
......@@ -14,13 +14,21 @@ using F16 = ck::half_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RCR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
// const std::vector<int> KBATCH{1, 2, 8, 32};
const std::vector<int> KBATCH{4};
using RRR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RCR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
// INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH));
const std::vector<int> KBATCH{1, 2, 3, 5, 8};
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_NK, RCR_F16_F16_F16, testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_KN,
RRR_F16_F16_F16_LargeK,
testing::Values(32, 64));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_NK,
RCR_F16_F16_F16_LargeK,
testing::Values(32, 64));
#include "test_grouped_gemm_ut_cases.inc"
#pragma once
// TEST_P(RRR_F16_F16_F16, TinyCases)
// {
// const std::vector<int> Ms{0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
// const std::vector<int> Ns(Ms.size(), 4608);
// const std::vector<int> Ks(Ms.size(), 384);
// const std::vector<int> StrideAs(Ms.size(), 384);
// const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 4608);
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// }
// TEST_P(RRR_F16_F16_F16, SmallCases)
// {
// const std::vector<int> Ms{2, 1, 1, 1, 1, 1, 3, 4, 3, 5, 2, 4, 2, 1, 0, 1};
// const std::vector<int> Ns(Ms.size(), 4608);
// const std::vector<int> Ks(Ms.size(), 384);
// const std::vector<int> StrideAs(Ms.size(), 384);
// const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 4608);
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// }
// TEST_P(RRR_F16_F16_F16, MidCases)
// {
// const std::vector<int> Ms{
// 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
// const std::vector<int> Ns(Ms.size(), 4608);
// const std::vector<int> Ks(Ms.size(), 384);
// const std::vector<int> StrideAs(Ms.size(), 384);
// const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 4608);
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// }
// TEST_P(RCR_F16_F16_F16, TinyCases)
// {
// const std::vector<int> Ms{0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
// const std::vector<int> Ns(Ms.size(), 768);
// const std::vector<int> Ks(Ms.size(), 4608);
// const std::vector<int> StrideAs(Ms.size(), 4608);
// const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 768);
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// }
// TEST_P(RCR_F16_F16_F16, SmallCases)
// {
// const std::vector<int> Ms{2, 1, 1, 1, 1, 1, 3, 4, 3, 5, 2, 4, 2, 1, 0, 1};
// const std::vector<int> Ns(Ms.size(), 768);
// const std::vector<int> Ks(Ms.size(), 4608);
// const std::vector<int> StrideAs(Ms.size(), 4608);
// const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 768);
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// }
TEST_P(RRR_F16_F16_F16, TinyCases)
{
const std::vector<int> Ms{0, 1};
const int N = 768;
const int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RRR_F16_F16_F16, SmallCases)
{
const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
const int N = 768;
const int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RRR_F16_F16_F16, MidCases)
{
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
const int N = 768;
const int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RRR_F16_F16_F16, Regular)
{
const std::vector<int> Ms{64, 128, 256};
const int N = 768;
const int K = 320;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RRR_F16_F16_F16, MNKPadded)
{
const std::vector<int> Ms{127, 150, 188, 210};
const int N = 136;
const int K = 280;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, TinyCases)
{
const std::vector<int> Ms{0, 1};
const int N = 768;
const int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, SmallCases)
{
const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
const int N = 768;
const int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, MidCases)
{
const std::vector<int> Ms{
// 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
167};
// const std::vector<int> Ns(Ms.size(), 768);
// const std::vector<int> Ks(Ms.size(), 4608);
// const std::vector<int> StrideAs(Ms.size(), 4608);
// const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 768);
const std::vector<int> Ns(Ms.size(), 256);
const std::vector<int> Ks(Ms.size(), 128);
const std::vector<int> StrideAs(Ms.size(), 128);
const std::vector<int> StrideBs(Ms.size(), 128);
const std::vector<int> StrideCs(Ms.size(), 256);
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
const int N = 768;
const int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, Regular)
{
const std::vector<int> Ms{32, 64, 128, 256};
const int N = 768;
const int K = 320;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, MNKPadded)
{
const std::vector<int> Ms{127, 150, 188, 210};
const int N = 136;
const int K = 280;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
{
const std::vector<int> Ms{188, 210};
const int N = 768;
const int K = 4096;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16_LargeK, TestLargeKBatch)
{
const std::vector<int> Ms{188, 210};
const int N = 768;
const int K = 4096;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
......@@ -56,15 +56,6 @@ class TestGroupedGemm : public testing::TestWithParam<int>
const std::vector<int>& StrideCs,
int kbatch = 1)
{
std::cout << "Ms: [" << serialize_range(Ms) << "] "
<< "Ns: [" << serialize_range(Ns) << "] "
<< "Ks: [" << serialize_range(Ks) << "] "
<< "StrideAs: [" << serialize_range(StrideAs) << "] "
<< "StrideBs: [" << serialize_range(StrideBs) << "] "
<< "StrideCs: [" << serialize_range(StrideCs) << "] "
<< "kbatch: " << kbatch << std::endl;
bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType,
EDataType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment