Unverified Commit 5311d1b3 authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

changed test for grouped_gemm to be random (#959)


Co-authored-by: default avatarJing Zhang <jizha@amd.com>
parent aa46039f
...@@ -60,14 +60,13 @@ int main() ...@@ -60,14 +60,13 @@ int main()
int sum_of_m = 0; int sum_of_m = 0;
Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; const int group_count = 16;
int group_count = Ms.size();
for(int i = 0; i < group_count; ++i) for(int i = 0; i < group_count; ++i)
{ {
Ns.push_back(768); Ms.push_back(256 + 256 * i);
Ks.push_back(4608); Ns.push_back(128 + 128 * i);
Ks.push_back(128 + 64 * i);
StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]); StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]); StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
......
...@@ -57,15 +57,13 @@ int main() ...@@ -57,15 +57,13 @@ int main()
int sum_of_m = 0; int sum_of_m = 0;
// Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; const int group_count = 16;
Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
int group_count = Ms.size();
for(int i = 0; i < group_count; ++i) for(int i = 0; i < group_count; ++i)
{ {
Ns.push_back(768); Ms.push_back(256 + 256 * i);
Ks.push_back(4608); Ns.push_back(128 + 128 * i);
Ks.push_back(128 + 64 * i);
StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]); StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]); StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
......
...@@ -58,14 +58,13 @@ int main() ...@@ -58,14 +58,13 @@ int main()
int sum_of_m = 0; int sum_of_m = 0;
Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; const int group_count = 16;
int group_count = Ms.size();
for(int i = 0; i < group_count; ++i) for(int i = 0; i < group_count; ++i)
{ {
Ns.push_back(768); Ms.push_back(256 + 256 * i);
Ks.push_back(4608); Ns.push_back(128 + 128 * i);
Ks.push_back(128 + 64 * i);
StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]); StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]); StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
......
...@@ -58,14 +58,13 @@ int main() ...@@ -58,14 +58,13 @@ int main()
int sum_of_m = 0; int sum_of_m = 0;
Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; const int group_count = 16;
int group_count = Ms.size();
for(int i = 0; i < group_count; ++i) for(int i = 0; i < group_count; ++i)
{ {
Ns.push_back(768); Ms.push_back(256 + 256 * i);
Ks.push_back(4608); Ns.push_back(128 + 128 * i);
Ks.push_back(128 + 64 * i);
StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]); StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]); StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
......
...@@ -296,13 +296,11 @@ int main(int argc, char* argv[]) ...@@ -296,13 +296,11 @@ int main(int argc, char* argv[])
problem_size.group_count = 16; problem_size.group_count = 16;
problem_size.Ms = {
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
for(int i = 0; i < problem_size.group_count; i++) for(int i = 0; i < problem_size.group_count; i++)
{ {
problem_size.Ns.push_back(768); problem_size.Ms.push_back(256 + 256 * i);
problem_size.Ks.push_back(4608); problem_size.Ns.push_back(128 + 128 * i);
problem_size.Ks.push_back(128 + 64 * i);
problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]);
......
...@@ -297,13 +297,11 @@ int main(int argc, char* argv[]) ...@@ -297,13 +297,11 @@ int main(int argc, char* argv[])
problem_size.group_count = 16; problem_size.group_count = 16;
problem_size.Ms = {
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
for(int i = 0; i < problem_size.group_count; i++) for(int i = 0; i < problem_size.group_count; i++)
{ {
problem_size.Ns.push_back(768); problem_size.Ms.push_back(256 + 256 * i);
problem_size.Ks.push_back(4608); problem_size.Ns.push_back(128 + 128 * i);
problem_size.Ks.push_back(128 + 64 * i);
problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]);
......
...@@ -66,13 +66,11 @@ int main(int argc, char* argv[]) ...@@ -66,13 +66,11 @@ int main(int argc, char* argv[])
problem_size.group_count = 16; problem_size.group_count = 16;
problem_size.Ms = {
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
for(int i = 0; i < problem_size.group_count; i++) for(int i = 0; i < problem_size.group_count; i++)
{ {
problem_size.Ns.push_back(768); problem_size.Ms.push_back(256 + 256 * i);
problem_size.Ks.push_back(4608); problem_size.Ns.push_back(128 + 128 * i);
problem_size.Ks.push_back(128 + 64 * i);
problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment