Commit 0c5f9438 authored by Adam Osewski's avatar Adam Osewski
Browse files

Constepxr everything!

parent 01f0831b
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
TYPED_TEST(TestGemmSplitK_MK_KN, SmallM) TYPED_TEST(TestGemmSplitK_MK_KN, SmallM)
{ {
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6}; std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
int N = 512; constexpr int N = 512;
int K = 320; constexpr int K = 320;
int StrideA = K; constexpr int StrideA = K;
int StrideB = N; constexpr int StrideB = N;
int StrideC = N; constexpr int StrideC = N;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC); this->Run(M, N, K, StrideA, StrideB, StrideC);
...@@ -17,12 +17,12 @@ TYPED_TEST(TestGemmSplitK_MK_KN, SmallM) ...@@ -17,12 +17,12 @@ TYPED_TEST(TestGemmSplitK_MK_KN, SmallM)
TYPED_TEST(TestGemmSplitK_MK_NK, SmallM) TYPED_TEST(TestGemmSplitK_MK_NK, SmallM)
{ {
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6}; std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
int N = 512; constexpr int N = 512;
int K = 320; constexpr int K = 320;
int StrideA = K; constexpr int StrideA = K;
int StrideB = K; constexpr int StrideB = K;
int StrideC = N; constexpr int StrideC = N;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC); this->Run(M, N, K, StrideA, StrideB, StrideC);
...@@ -31,11 +31,11 @@ TYPED_TEST(TestGemmSplitK_MK_NK, SmallM) ...@@ -31,11 +31,11 @@ TYPED_TEST(TestGemmSplitK_MK_NK, SmallM)
TYPED_TEST(TestGemmSplitK_KM_KN, SmallM) TYPED_TEST(TestGemmSplitK_KM_KN, SmallM)
{ {
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6}; std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
int N = 512; constexpr int N = 512;
int K = 320; constexpr int K = 320;
int StrideB = N; constexpr int StrideB = N;
int StrideC = N; constexpr int StrideC = N;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC); this->Run(M, N, K, M, StrideB, StrideC);
...@@ -44,11 +44,11 @@ TYPED_TEST(TestGemmSplitK_KM_KN, SmallM) ...@@ -44,11 +44,11 @@ TYPED_TEST(TestGemmSplitK_KM_KN, SmallM)
TYPED_TEST(TestGemmSplitK_KM_NK, SmallM) TYPED_TEST(TestGemmSplitK_KM_NK, SmallM)
{ {
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6}; std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
int N = 512; constexpr int N = 512;
int K = 320; constexpr int K = 320;
int StrideB = K; constexpr int StrideB = K;
int StrideC = N; constexpr int StrideC = N;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC); this->Run(M, N, K, M, StrideB, StrideC);
...@@ -57,12 +57,12 @@ TYPED_TEST(TestGemmSplitK_KM_NK, SmallM) ...@@ -57,12 +57,12 @@ TYPED_TEST(TestGemmSplitK_KM_NK, SmallM)
TYPED_TEST(TestGemmSplitK_MK_KN, MidLargeM) TYPED_TEST(TestGemmSplitK_MK_KN, MidLargeM)
{ {
std::vector<int> Ms{127, 255, 312, 799, 1573}; std::vector<int> Ms{127, 255, 312, 799, 1573};
int N = 512; constexpr int N = 512;
int K = 320; constexpr int K = 320;
int StrideA = K; constexpr int StrideA = K;
int StrideB = N; constexpr int StrideB = N;
int StrideC = N; constexpr int StrideC = N;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC); this->Run(M, N, K, StrideA, StrideB, StrideC);
...@@ -71,12 +71,12 @@ TYPED_TEST(TestGemmSplitK_MK_KN, MidLargeM) ...@@ -71,12 +71,12 @@ TYPED_TEST(TestGemmSplitK_MK_KN, MidLargeM)
TYPED_TEST(TestGemmSplitK_MK_NK, MidLargeM) TYPED_TEST(TestGemmSplitK_MK_NK, MidLargeM)
{ {
std::vector<int> Ms{127, 255, 312, 799, 1573}; std::vector<int> Ms{127, 255, 312, 799, 1573};
int N = 512; constexpr int N = 512;
int K = 320; constexpr int K = 320;
int StrideA = K; constexpr int StrideA = K;
int StrideB = K; constexpr int StrideB = K;
int StrideC = N; constexpr int StrideC = N;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC); this->Run(M, N, K, StrideA, StrideB, StrideC);
...@@ -85,11 +85,11 @@ TYPED_TEST(TestGemmSplitK_MK_NK, MidLargeM) ...@@ -85,11 +85,11 @@ TYPED_TEST(TestGemmSplitK_MK_NK, MidLargeM)
TYPED_TEST(TestGemmSplitK_KM_KN, MidLargeM) TYPED_TEST(TestGemmSplitK_KM_KN, MidLargeM)
{ {
std::vector<int> Ms{127, 255, 312, 799, 1573}; std::vector<int> Ms{127, 255, 312, 799, 1573};
int N = 512; constexpr int N = 512;
int K = 320; constexpr int K = 320;
int StrideB = N; constexpr int StrideB = N;
int StrideC = N; constexpr int StrideC = N;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC); this->Run(M, N, K, M, StrideB, StrideC);
...@@ -98,11 +98,11 @@ TYPED_TEST(TestGemmSplitK_KM_KN, MidLargeM) ...@@ -98,11 +98,11 @@ TYPED_TEST(TestGemmSplitK_KM_KN, MidLargeM)
TYPED_TEST(TestGemmSplitK_KM_NK, MidLargeM) TYPED_TEST(TestGemmSplitK_KM_NK, MidLargeM)
{ {
std::vector<int> Ms{127, 255, 312, 799, 1573}; std::vector<int> Ms{127, 255, 312, 799, 1573};
int N = 512; constexpr int N = 512;
int K = 320; constexpr int K = 320;
int StrideB = K; constexpr int StrideB = K;
int StrideC = N; constexpr int StrideC = N;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC); this->Run(M, N, K, M, StrideB, StrideC);
...@@ -111,12 +111,12 @@ TYPED_TEST(TestGemmSplitK_KM_NK, MidLargeM) ...@@ -111,12 +111,12 @@ TYPED_TEST(TestGemmSplitK_KM_NK, MidLargeM)
TYPED_TEST(TestGemmSplitK_MK_KN, PaddK) TYPED_TEST(TestGemmSplitK_MK_KN, PaddK)
{ {
std::vector<int> Ms{127}; std::vector<int> Ms{127};
int N = 512; constexpr int N = 512;
int K = 437; constexpr int K = 437;
int StrideA = K; constexpr int StrideA = K;
int StrideB = N; constexpr int StrideB = N;
int StrideC = N; constexpr int StrideC = N;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC); this->Run(M, N, K, StrideA, StrideB, StrideC);
...@@ -125,12 +125,12 @@ TYPED_TEST(TestGemmSplitK_MK_KN, PaddK) ...@@ -125,12 +125,12 @@ TYPED_TEST(TestGemmSplitK_MK_KN, PaddK)
TYPED_TEST(TestGemmSplitK_MK_NK, PaddK) TYPED_TEST(TestGemmSplitK_MK_NK, PaddK)
{ {
std::vector<int> Ms{127}; std::vector<int> Ms{127};
int N = 512; constexpr int N = 512;
int K = 437; constexpr int K = 437;
int StrideA = K; constexpr int StrideA = K;
int StrideB = K; constexpr int StrideB = K;
int StrideC = N; constexpr int StrideC = N;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC); this->Run(M, N, K, StrideA, StrideB, StrideC);
...@@ -139,11 +139,11 @@ TYPED_TEST(TestGemmSplitK_MK_NK, PaddK) ...@@ -139,11 +139,11 @@ TYPED_TEST(TestGemmSplitK_MK_NK, PaddK)
TYPED_TEST(TestGemmSplitK_KM_KN, PaddK) TYPED_TEST(TestGemmSplitK_KM_KN, PaddK)
{ {
std::vector<int> Ms{127}; std::vector<int> Ms{127};
int N = 512; constexpr int N = 512;
int K = 437; constexpr int K = 437;
int StrideB = N; constexpr int StrideB = N;
int StrideC = N; constexpr int StrideC = N;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC); this->Run(M, N, K, M, StrideB, StrideC);
...@@ -152,11 +152,11 @@ TYPED_TEST(TestGemmSplitK_KM_KN, PaddK) ...@@ -152,11 +152,11 @@ TYPED_TEST(TestGemmSplitK_KM_KN, PaddK)
TYPED_TEST(TestGemmSplitK_KM_NK, PaddK) TYPED_TEST(TestGemmSplitK_KM_NK, PaddK)
{ {
std::vector<int> Ms{127}; std::vector<int> Ms{127};
int N = 512; constexpr int N = 512;
int K = 437; constexpr int K = 437;
int StrideB = K; constexpr int StrideB = K;
int StrideC = N; constexpr int StrideC = N;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC); this->Run(M, N, K, M, StrideB, StrideC);
...@@ -165,12 +165,12 @@ TYPED_TEST(TestGemmSplitK_KM_NK, PaddK) ...@@ -165,12 +165,12 @@ TYPED_TEST(TestGemmSplitK_KM_NK, PaddK)
TYPED_TEST(TestGemmSplitK_MK_KN, Regular) TYPED_TEST(TestGemmSplitK_MK_KN, Regular)
{ {
std::vector<int> Ms{512}; std::vector<int> Ms{512};
int N = 512; constexpr int N = 512;
int K = 512; constexpr int K = 512;
int StrideA = K; constexpr int StrideA = K;
int StrideB = N; constexpr int StrideB = N;
int StrideC = N; constexpr int StrideC = N;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC); this->Run(M, N, K, StrideA, StrideB, StrideC);
...@@ -179,12 +179,12 @@ TYPED_TEST(TestGemmSplitK_MK_KN, Regular) ...@@ -179,12 +179,12 @@ TYPED_TEST(TestGemmSplitK_MK_KN, Regular)
TYPED_TEST(TestGemmSplitK_MK_NK, Regular) TYPED_TEST(TestGemmSplitK_MK_NK, Regular)
{ {
std::vector<int> Ms{512}; std::vector<int> Ms{512};
int N = 512; constexpr int N = 512;
int K = 512; constexpr int K = 512;
int StrideA = K; constexpr int StrideA = K;
int StrideB = K; constexpr int StrideB = K;
int StrideC = N; constexpr int StrideC = N;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC); this->Run(M, N, K, StrideA, StrideB, StrideC);
...@@ -193,11 +193,11 @@ TYPED_TEST(TestGemmSplitK_MK_NK, Regular) ...@@ -193,11 +193,11 @@ TYPED_TEST(TestGemmSplitK_MK_NK, Regular)
TYPED_TEST(TestGemmSplitK_KM_KN, Regular) TYPED_TEST(TestGemmSplitK_KM_KN, Regular)
{ {
std::vector<int> Ms{512}; std::vector<int> Ms{512};
int N = 512; constexpr int N = 512;
int K = 512; constexpr int K = 512;
int StrideB = N; constexpr int StrideB = N;
int StrideC = N; constexpr int StrideC = N;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC); this->Run(M, N, K, M, StrideB, StrideC);
...@@ -206,11 +206,11 @@ TYPED_TEST(TestGemmSplitK_KM_KN, Regular) ...@@ -206,11 +206,11 @@ TYPED_TEST(TestGemmSplitK_KM_KN, Regular)
TYPED_TEST(TestGemmSplitK_KM_NK, Regular) TYPED_TEST(TestGemmSplitK_KM_NK, Regular)
{ {
std::vector<int> Ms{512}; std::vector<int> Ms{512};
int N = 512; constexpr int N = 512;
int K = 512; constexpr int K = 512;
int StrideB = K; constexpr int StrideB = K;
int StrideC = N; constexpr int StrideC = N;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC); this->Run(M, N, K, M, StrideB, StrideC);
......
...@@ -33,10 +33,10 @@ class TestGemmSplitK : public testing::Test ...@@ -33,10 +33,10 @@ class TestGemmSplitK : public testing::Test
using CDataType = std::tuple_element_t<4, Tuple>; using CDataType = std::tuple_element_t<4, Tuple>;
public: public:
bool verify_ = true; static constexpr bool verify_ = true;
int init_method_ = 1; // decimal value initialization static constexpr int init_method_ = 1; // decimal value initialization
bool log_ = false; static constexpr bool log_ = false;
bool bench_ = false; // measure kernel performance static constexpr bool bench_ = false; // measure kernel performance
std::vector<int> k_batches_; std::vector<int> k_batches_;
void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; } void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; }
......
...@@ -45,8 +45,8 @@ class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test ...@@ -45,8 +45,8 @@ class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test
TEST_F(TestGGemmSplitKInterface_MKNKMN, TileSize) TEST_F(TestGGemmSplitKInterface_MKNKMN, TileSize)
{ {
std::vector<int> Ms{128, 256, 188, 512}; std::vector<int> Ms{128, 256, 188, 512};
int N = 256; constexpr int N = 256;
int K = 128; constexpr int K = 128;
std::vector<int> Ns(Ms.size(), N); std::vector<int> Ns(Ms.size(), N);
std::vector<int> Ks(Ms.size(), K); std::vector<int> Ks(Ms.size(), K);
...@@ -70,8 +70,8 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth) ...@@ -70,8 +70,8 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth)
using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 32, 8, 4, 8, 8>; using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 32, 8, 4, 8, 8>;
std::vector<int> Ms{128, 256, 256, 512}; std::vector<int> Ms{128, 256, 256, 512};
int N = 256; constexpr int N = 256;
int K = 512; constexpr int K = 512;
std::vector<int> Ns(Ms.size(), N); std::vector<int> Ns(Ms.size(), N);
std::vector<int> Ks(Ms.size(), K); std::vector<int> Ks(Ms.size(), K);
...@@ -96,9 +96,9 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth) ...@@ -96,9 +96,9 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth)
TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops) TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops)
{ {
std::vector<int> Ms{128, 256, 256, 512}; std::vector<int> Ms{128, 256, 256, 512};
int N = 256; constexpr int N = 256;
int K = 128; constexpr int K = 128;
int kbatch = 4; constexpr int kbatch = 4;
std::vector<int> Ns(Ms.size(), N); std::vector<int> Ns(Ms.size(), N);
std::vector<int> Ks(Ms.size(), K); std::vector<int> Ks(Ms.size(), K);
...@@ -152,8 +152,8 @@ class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test ...@@ -152,8 +152,8 @@ class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test
TEST_F(TestGGemmSplitKInterface_KMKNNM, TileSize) TEST_F(TestGGemmSplitKInterface_KMKNNM, TileSize)
{ {
std::vector<int> Ms{128, 256, 188, 512}; std::vector<int> Ms{128, 256, 188, 512};
int N = 256; constexpr int N = 256;
int K = 128; constexpr int K = 128;
std::vector<int> Ns(Ms.size(), N); std::vector<int> Ns(Ms.size(), N);
std::vector<int> Ks(Ms.size(), K); std::vector<int> Ks(Ms.size(), K);
...@@ -177,8 +177,8 @@ TEST_F(TestGGemmSplitKInterface_KMKNNM, VectorLoadWidth) ...@@ -177,8 +177,8 @@ TEST_F(TestGGemmSplitKInterface_KMKNNM, VectorLoadWidth)
using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 32, 8, 2, 8, 4>; using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 32, 8, 2, 8, 4>;
std::vector<int> Ms{128, 256, 256, 512}; std::vector<int> Ms{128, 256, 256, 512};
int N = 256; constexpr int N = 256;
int K = 512; constexpr int K = 512;
std::vector<int> Ns(Ms.size(), N); std::vector<int> Ns(Ms.size(), N);
std::vector<int> Ks(Ms.size(), K); std::vector<int> Ks(Ms.size(), K);
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
TEST_P(RRR_F16_F16_F16, TinyCases) TEST_P(RRR_F16_F16_F16, TinyCases)
{ {
const std::vector<int> Ms{0, 1}; const std::vector<int> Ms{0, 1};
const int N = 768; constexpr int N = 768;
const int K = 544; constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
...@@ -18,8 +18,8 @@ TEST_P(RRR_F16_F16_F16, TinyCases) ...@@ -18,8 +18,8 @@ TEST_P(RRR_F16_F16_F16, TinyCases)
TEST_P(RRR_F16_F16_F16, SmallCases) TEST_P(RRR_F16_F16_F16, SmallCases)
{ {
const std::vector<int> Ms{2, 1, 3, 4, 5, 0}; const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
const int N = 768; constexpr int N = 768;
const int K = 544; constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
...@@ -33,8 +33,8 @@ TEST_P(RRR_F16_F16_F16, SmallCases) ...@@ -33,8 +33,8 @@ TEST_P(RRR_F16_F16_F16, SmallCases)
TEST_P(RRR_F16_F16_F16, MidCases) TEST_P(RRR_F16_F16_F16, MidCases)
{ {
const std::vector<int> Ms{167, 183, 177, 153, 139, 204}; const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
const int N = 768; constexpr int N = 768;
const int K = 544; constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
...@@ -48,8 +48,8 @@ TEST_P(RRR_F16_F16_F16, MidCases) ...@@ -48,8 +48,8 @@ TEST_P(RRR_F16_F16_F16, MidCases)
TEST_P(RRR_F16_F16_F16, Regular) TEST_P(RRR_F16_F16_F16, Regular)
{ {
const std::vector<int> Ms{64, 128, 256}; const std::vector<int> Ms{64, 128, 256};
const int N = 768; constexpr int N = 768;
const int K = 320; constexpr int K = 320;
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
...@@ -63,8 +63,8 @@ TEST_P(RRR_F16_F16_F16, Regular) ...@@ -63,8 +63,8 @@ TEST_P(RRR_F16_F16_F16, Regular)
TEST_P(RRR_F16_F16_F16, MNKPadded) TEST_P(RRR_F16_F16_F16, MNKPadded)
{ {
const std::vector<int> Ms{127, 150, 188, 210}; const std::vector<int> Ms{127, 150, 188, 210};
const int N = 136; constexpr int N = 136;
const int K = 280; constexpr int K = 280;
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
...@@ -78,8 +78,8 @@ TEST_P(RRR_F16_F16_F16, MNKPadded) ...@@ -78,8 +78,8 @@ TEST_P(RRR_F16_F16_F16, MNKPadded)
TEST_P(RCR_F16_F16_F16, TinyCases) TEST_P(RCR_F16_F16_F16, TinyCases)
{ {
const std::vector<int> Ms{0, 1}; const std::vector<int> Ms{0, 1};
const int N = 768; constexpr int N = 768;
const int K = 544; constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
...@@ -92,8 +92,8 @@ TEST_P(RCR_F16_F16_F16, TinyCases) ...@@ -92,8 +92,8 @@ TEST_P(RCR_F16_F16_F16, TinyCases)
TEST_P(RCR_F16_F16_F16, SmallCases) TEST_P(RCR_F16_F16_F16, SmallCases)
{ {
const std::vector<int> Ms{2, 1, 3, 4, 5, 0}; const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
const int N = 768; constexpr int N = 768;
const int K = 544; constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
...@@ -107,8 +107,8 @@ TEST_P(RCR_F16_F16_F16, SmallCases) ...@@ -107,8 +107,8 @@ TEST_P(RCR_F16_F16_F16, SmallCases)
TEST_P(RCR_F16_F16_F16, MidCases) TEST_P(RCR_F16_F16_F16, MidCases)
{ {
const std::vector<int> Ms{167, 183, 177, 153, 139, 204}; const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
const int N = 768; constexpr int N = 768;
const int K = 544; constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
...@@ -122,8 +122,8 @@ TEST_P(RCR_F16_F16_F16, MidCases) ...@@ -122,8 +122,8 @@ TEST_P(RCR_F16_F16_F16, MidCases)
TEST_P(RCR_F16_F16_F16, Regular) TEST_P(RCR_F16_F16_F16, Regular)
{ {
const std::vector<int> Ms{32, 64, 128, 256}; const std::vector<int> Ms{32, 64, 128, 256};
const int N = 768; constexpr int N = 768;
const int K = 320; constexpr int K = 320;
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
...@@ -137,8 +137,8 @@ TEST_P(RCR_F16_F16_F16, Regular) ...@@ -137,8 +137,8 @@ TEST_P(RCR_F16_F16_F16, Regular)
TEST_P(RCR_F16_F16_F16, MNKPadded) TEST_P(RCR_F16_F16_F16, MNKPadded)
{ {
const std::vector<int> Ms{127, 150, 188, 210}; const std::vector<int> Ms{127, 150, 188, 210};
const int N = 136; constexpr int N = 136;
const int K = 280; constexpr int K = 280;
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
...@@ -152,8 +152,8 @@ TEST_P(RCR_F16_F16_F16, MNKPadded) ...@@ -152,8 +152,8 @@ TEST_P(RCR_F16_F16_F16, MNKPadded)
TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch) TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
{ {
const std::vector<int> Ms{188, 210}; const std::vector<int> Ms{188, 210};
const int N = 768; constexpr int N = 768;
const int K = 4096; constexpr int K = 4096;
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
...@@ -167,8 +167,8 @@ TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch) ...@@ -167,8 +167,8 @@ TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
TEST_P(RCR_F16_F16_F16_LargeK, TestLargeKBatch) TEST_P(RCR_F16_F16_F16_LargeK, TestLargeKBatch)
{ {
const std::vector<int> Ms{188, 210}; const std::vector<int> Ms{188, 210};
const int N = 768; constexpr int N = 768;
const int K = 4096; constexpr int K = 4096;
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
......
...@@ -50,10 +50,10 @@ class TestGroupedGemm : public testing::TestWithParam<int> ...@@ -50,10 +50,10 @@ class TestGroupedGemm : public testing::TestWithParam<int>
using EDataType = std::tuple_element_t<5, Tuple>; using EDataType = std::tuple_element_t<5, Tuple>;
public: public:
bool verify_ = true; static constexpr bool verify_ = true;
int init_method_ = 0; // decimal value initialization static constexpr int init_method_ = 1; // decimal value initialization
bool log_ = false; static constexpr bool log_ = false;
bool bench_ = false; // measure kernel performance static constexpr bool bench_ = false; // measure kernel performance
void SetUp() override {} void SetUp() override {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment