Unverified Commit 228bc13a authored by M.Emin Ozturk's avatar M.Emin Ozturk Committed by GitHub
Browse files

Merge branch 'develop' into gemm_bf16_sk_muozturk

parents 39bc3b83 e758d006
...@@ -34,13 +34,19 @@ using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; ...@@ -34,13 +34,19 @@ using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("a_layout", "R", "A tensor data layout - Row by default") arg_parser.insert("Ms", "", "M dimensions - empty by default.")
.insert("b_layout", "R", "B tensor data layout - Row by default") .insert("Ns", "", "N dimensions - empty by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default") .insert("Ks", "", "K dimensions - empty by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU") .insert("stride_As", "", "Tensor A strides - it is empty by default.")
.insert("warmup", "10", "number of iterations before benchmark the kernel") .insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
.insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
.insert("group_count", "16", "group count"); .insert("a_layout", "R", "A tensor data layout - Row by default.")
.insert("b_layout", "R", "B tensor data layout - Row by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "16", "group count.");
bool result = arg_parser.parse(argc, argv); bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
......
...@@ -53,26 +53,34 @@ int run_grouped_gemm_example_with_layouts(int argc, ...@@ -53,26 +53,34 @@ int run_grouped_gemm_example_with_layouts(int argc,
return -1; return -1;
}; };
auto valid_input_data = [&](int group_count, const auto&... args) {
return !(args.empty() || ...) && group_count == (args.size() == ...);
};
const int group_count = arg_parser.get_int("group_count"); const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat"); const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup"); const int warmup = arg_parser.get_int("warmup");
std::vector<ck_tile::index_t> Ms; std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns; std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks; std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
std::vector<ck_tile::index_t> stride_As; std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
std::vector<ck_tile::index_t> stride_Bs; std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
std::vector<ck_tile::index_t> stride_Cs; std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
for(int i = 0; i < group_count; i++) if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs))
{ {
Ms.push_back(256 + 256 * i); std::cout << "Please check the input data. Default values will be used." << std::endl;
Ns.push_back(128 + 128 * i); for(int i = 0; i < group_count; i++)
Ks.push_back(128 + 64 * i); {
Ms.push_back(256 + 256 * i);
Ns.push_back(128 + 128 * i);
Ks.push_back(128 + 64 * i);
stride_As.push_back(Ks[i]); stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]); stride_Bs.push_back(Ks[i]);
stride_Cs.push_back(Ns[i]); stride_Cs.push_back(Ns[i]);
}
} }
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors; std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
......
...@@ -15,11 +15,14 @@ ...@@ -15,11 +15,14 @@
namespace ck_tile { namespace ck_tile {
/* /*
* a host side utility, arg parser for * a host side utility, arg parser for, either
* -[key0]=[value0] -[key1]=[value1] ... * -[key0] = [value0, value1, value2]
* or
* -[key0]=[value0] -[key1]=[value1] ...
*/ */
class ArgParser class ArgParser
{ {
public: public:
class Arg class Arg
{ {
...@@ -187,6 +190,45 @@ class ArgParser ...@@ -187,6 +190,45 @@ class ArgParser
return value; return value;
} }
std::vector<std::string> get_string_vec(const std::string& name,
const std::string& delimiter = ",") const
{
if(get_str(name).empty())
{
return {};
}
std::string s = get_str(name);
std::vector<std::string> tokens;
size_t pos = 0;
std::string token;
while((pos = s.find(delimiter)) != std::string::npos)
{
token = s.substr(0, pos);
tokens.push_back(token);
s.erase(0, pos + delimiter.length());
}
tokens.push_back(s);
return tokens;
}
std::vector<int> get_int_vec(const std::string& name, const std::string& delimiter = ",") const
{
if(get_str(name).empty())
{
return {};
}
const std::vector<std::string> args = get_string_vec(name, delimiter);
std::vector<int> tokens;
tokens.reserve(static_cast<int>(args.size()));
for(const std::string& token : args)
{
int value = atoi(token.c_str());
tokens.push_back(value);
}
return tokens;
}
private: private:
std::unordered_map<std::string, Arg> input_map; std::unordered_map<std::string, Arg> input_map;
std::vector<std::string> keys; std::vector<std::string> keys;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment