Commit 7ed791b8 authored by Adam Osewski's avatar Adam Osewski
Browse files

Update grouped_gemm multi d splitk example.

Enable passing in cmdline grouped gemm arguments.
parent 4841d991
...@@ -295,40 +295,73 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -295,40 +295,73 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
return pass; return pass;
} }
std::vector<int> argToIntArray(char* input)
{
std::vector<int> out;
std::istringstream in(input);
std::string item;
while(std::getline(in, item, ','))
{
out.push_back(std::stoi(item));
}
return out;
}
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
ProblemSize problem_size; ProblemSize problem_size;
ExecutionConfig config; ExecutionConfig config;
std::vector<ck::index_t> Ms{64}; if(argc < 11)
{
problem_size.group_count = Ms.size(); std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
for(int i = 0; i < problem_size.group_count; i++) for(int i = 0; i < problem_size.group_count; i++)
{ {
problem_size.Ms.push_back(Ms[i]); problem_size.Ms.push_back(Ms[i]);
problem_size.Ns.push_back(128); problem_size.Ns.push_back(250);
problem_size.Ks.push_back(128); problem_size.Ks.push_back(4608);
problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]); problem_size.stride_Cs.push_back(problem_size.Ns[i]);
} }
if(argc == 5) config.do_verification = 1;
config.init_method = 3;
config.time_kernel = 0;
config.k_batch = 64;
std::cout
<< "Usage:\n"
<< "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=n0, 1=yes)\n"
<< "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n"
<< "arg10: k_batch (> 0)\n"
<< "... setting default values." << std::endl;
}
else
{ {
config.do_verification = std::stoi(argv[1]); config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]); config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]); config.time_kernel = std::stoi(argv[3]);
config.k_batch = std::stoi(argv[4]); config.k_batch = std::stoi(argv[10]);
}
else problem_size.Ms = argToIntArray(argv[4]);
{ problem_size.Ns = argToIntArray(argv[5]);
printf("arg1: verification (0=no, 1=yes)\n"); problem_size.Ks = argToIntArray(argv[6]);
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n"); problem_size.stride_As = argToIntArray(argv[7]);
printf("arg4: k_batch (> 0)\n"); problem_size.stride_Bs = argToIntArray(argv[8]);
exit(0); problem_size.stride_Cs = argToIntArray(argv[9]);
problem_size.group_count = problem_size.Ms.size();
} }
return !run_grouped_gemm(problem_size, config); return !run_grouped_gemm(problem_size, config);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment