Commit 8f2b2d76 authored by Adam Osewski's avatar Adam Osewski
Browse files

Add more test cases for gemm splitk

parent 12cd4c72
...@@ -53,3 +53,165 @@ TYPED_TEST(TestGemmSplitK_KM_NK, SmallM) ...@@ -53,3 +53,165 @@ TYPED_TEST(TestGemmSplitK_KM_NK, SmallM)
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC); this->Run(M, N, K, M, StrideB, StrideC);
} }
TYPED_TEST(TestGemmSplitK_MK_KN, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
int N = 512;
int K = 320;
int StrideA = K;
int StrideB = N;
int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_MK_NK, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
int N = 512;
int K = 320;
int StrideA = K;
int StrideB = K;
int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_KM_KN, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
int N = 512;
int K = 320;
int StrideB = N;
int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_KM_NK, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
int N = 512;
int K = 320;
int StrideB = K;
int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_MK_KN, PaddK)
{
std::vector<int> Ms{127};
int N = 512;
int K = 437;
int StrideA = K;
int StrideB = N;
int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_MK_NK, PaddK)
{
std::vector<int> Ms{127};
int N = 512;
int K = 437;
int StrideA = K;
int StrideB = K;
int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_KM_KN, PaddK)
{
std::vector<int> Ms{127};
int N = 512;
int K = 437;
int StrideB = N;
int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_KM_NK, PaddK)
{
std::vector<int> Ms{127};
int N = 512;
int K = 437;
int StrideB = K;
int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_MK_KN, Regular)
{
std::vector<int> Ms{512};
int N = 512;
int K = 512;
int StrideA = K;
int StrideB = N;
int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_MK_NK, Regular)
{
std::vector<int> Ms{512};
int N = 512;
int K = 512;
int StrideA = K;
int StrideB = K;
int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_KM_KN, Regular)
{
std::vector<int> Ms{512};
int N = 512;
int K = 512;
int StrideB = N;
int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_KM_NK, Regular)
{
std::vector<int> Ms{512};
int N = 512;
int K = 512;
int StrideB = K;
int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment