"vscode:/vscode.git/clone" did not exist on "a109f4619cc8d74eb080832a2f45fa738b754d19"
Commit c20d0fc2 authored by Paul's avatar Paul
Browse files

Add read command

parent b726fc1a
add_executable(driver main.cpp) add_executable(driver main.cpp)
target_link_libraries(driver migraphx_cpu) target_link_libraries(driver migraphx_cpu migraphx_onnx migraphx_tf)
...@@ -57,17 +57,19 @@ struct argument_parser ...@@ -57,17 +57,19 @@ struct argument_parser
std::function<bool(argument_parser&, const std::vector<std::string>&)> action{}; std::function<bool(argument_parser&, const std::vector<std::string>&)> action{};
std::string type = ""; std::string type = "";
std::string help = ""; std::string help = "";
std::string metavar = "";
unsigned nargs = 1;
}; };
template <class T, class... Fs> template <class T, class... Fs>
void add(T& x, std::vector<std::string> flags, Fs... fs) void add(T& x, std::vector<std::string> flags, Fs... fs)
{ {
arguments.emplace_back(flags, [&](auto&&, const std::vector<std::string>& params) { arguments.push_back({flags, [&](auto&&, const std::vector<std::string>& params) {
if(params.empty()) if(params.empty())
throw std::runtime_error("Flag with no value."); throw std::runtime_error("Flag with no value.");
x = value_parser<T>::apply(params.back()); x = value_parser<T>::apply(params.back());
return false; return false;
}); }});
argument& arg = arguments.back(); argument& arg = arguments.back();
arg.type = migraphx::get_type_name<T>(); arg.type = migraphx::get_type_name<T>();
...@@ -81,9 +83,17 @@ struct argument_parser ...@@ -81,9 +83,17 @@ struct argument_parser
argument& arg = arguments.back(); argument& arg = arguments.back();
arg.type = ""; arg.type = "";
arg.nargs = 0;
migraphx::each_args([&](auto f) { f(x, arg); }, fs...); migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
} }
static auto nargs(unsigned n = 1)
{
return [=](auto&&, auto& arg) {
arg.nargs = n;
};
}
template <class F> template <class F>
static auto write_action(F f) static auto write_action(F f)
{ {
...@@ -99,6 +109,7 @@ struct argument_parser ...@@ -99,6 +109,7 @@ struct argument_parser
static auto do_action(F f) static auto do_action(F f)
{ {
return [=](auto&, auto& arg) { return [=](auto&, auto& arg) {
arg.nargs = 0;
arg.action = [&, f](auto& self, const std::vector<std::string>&) { arg.action = [&, f](auto& self, const std::vector<std::string>&) {
f(self); f(self);
return true; return true;
...@@ -106,7 +117,7 @@ struct argument_parser ...@@ -106,7 +117,7 @@ struct argument_parser
}; };
} }
static auto write_range() static auto append()
{ {
return write_action([](auto&, auto& x, auto& params) { return write_action([](auto&, auto& x, auto& params) {
using type = typename decltype(params)::value_type; using type = typename decltype(params)::value_type;
...@@ -124,6 +135,11 @@ struct argument_parser ...@@ -124,6 +135,11 @@ struct argument_parser
{ {
std::cout << std::endl; std::cout << std::endl;
std::string prefix = " "; std::string prefix = " ";
if (arg.flags.empty())
{
std::cout << prefix;
std::cout << arg.metavar;
}
for(const std::string& a : arg.flags) for(const std::string& a : arg.flags)
{ {
std::cout << prefix; std::cout << prefix;
...@@ -146,10 +162,16 @@ struct argument_parser ...@@ -146,10 +162,16 @@ struct argument_parser
return [=](auto&, auto& arg) { arg.help = help; }; return [=](auto&, auto& arg) { arg.help = help; };
} }
static auto metavar(std::string metavar)
{
return [=](auto&, auto& arg) { arg.metavar = metavar; };
}
template <class T> template <class T>
static auto set_value(T value) static auto set_value(T value)
{ {
return [=](auto& x, auto& arg) { return [=](auto& x, auto& arg) {
arg.nargs = 0;
arg.type = ""; arg.type = "";
arg.action = [&, value](auto&, const std::vector<std::string>&) { arg.action = [&, value](auto&, const std::vector<std::string>&) {
x = value; x = value;
...@@ -158,25 +180,32 @@ struct argument_parser ...@@ -158,25 +180,32 @@ struct argument_parser
}; };
} }
void parse(std::vector<std::string> args) bool parse(std::vector<std::string> args)
{ {
std::set<std::string> keywords; std::unordered_map<std::string, unsigned> keywords;
for(auto&& arg : arguments) for(auto&& arg : arguments)
{ {
keywords.insert(arg.flags.begin(), arg.flags.end()); for(auto&& flag:arg.flags)
keywords[flag] = arg.nargs + 1;
} }
auto arg_map = generic_parse(args, [&](std::string x) { return (keywords.count(x) > 0); }); auto arg_map = generic_parse(args, [&](std::string x) {
return keywords[x];
});
for(auto&& arg : arguments) for(auto&& arg : arguments)
{ {
auto flags = arg.flags;
if (flags.empty())
flags = {""};
for(auto&& flag : arg.flags) for(auto&& flag : arg.flags)
{ {
if(arg_map.count(flag) > 0) if(arg_map.count(flag) > 0)
{ {
if(arg.action(*this, arg_map[flag])) if(arg.action(*this, arg_map[flag]))
return; return true;
} }
} }
} }
return false;
} }
using string_map = std::unordered_map<std::string, std::vector<std::string>>; using string_map = std::unordered_map<std::string, std::vector<std::string>>;
...@@ -186,16 +215,28 @@ struct argument_parser ...@@ -186,16 +215,28 @@ struct argument_parser
string_map result; string_map result;
std::string flag; std::string flag;
bool clear = false;
for(auto&& x : as) for(auto&& x : as)
{ {
if(is_keyword(x)) auto k = is_keyword(x);
if(k > 0)
{ {
flag = x; flag = x;
result[flag]; // Ensure the flag exists result[flag]; // Ensure the flag exists
if (k == 1)
flag = "";
else if (k == 2)
clear = true;
else
clear = false;
} }
else else
{ {
result[flag].push_back(x); result[flag].push_back(x);
if (clear)
flag = "";
clear = false;
} }
} }
return result; return result;
......
...@@ -26,16 +26,25 @@ std::string command_name() ...@@ -26,16 +26,25 @@ std::string command_name()
return name.substr(name.rfind("::") + 2); return name.substr(name.rfind("::") + 2);
} }
template <class T> template<class T>
int auto_register_command() void run_command(std::vector<std::string> args, bool add_help=false)
{ {
auto& m = get_commands();
m[command_name<T>()] = [](std::vector<std::string> args) {
T x; T x;
argument_parser ap; argument_parser ap;
if (add_help)
ap.add(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help());
x.parse(ap); x.parse(ap);
ap.parse(args); if (ap.parse(args))
return;
x.run(); x.run();
}
template <class T>
int auto_register_command()
{
auto& m = get_commands();
m[command_name<T>()] = [](std::vector<std::string> args) {
run_command<T>(args, true);
}; };
return 0; return 0;
} }
...@@ -50,6 +59,11 @@ struct command ...@@ -50,6 +59,11 @@ struct command
std::integral_constant<decltype(&static_register), &static_register>; std::integral_constant<decltype(&static_register), &static_register>;
}; };
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
#endif
template <class T> template <class T>
int command<T>::static_register = auto_register_command<T>(); // NOLINT int command<T>::static_register = auto_register_command<T>(); // NOLINT
......
#include "argument_parser.hpp" #include "argument_parser.hpp"
#include "command.hpp" #include "command.hpp"
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
struct loader
{
std::string file;
std::string type;
bool is_nhwc = false;
void parse(argument_parser& ap)
{
ap.add(file, {}, ap.metavar("<input file>"));
ap.add(type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
ap.add(type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
ap.add(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
ap.add(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
}
program load()
{
program p;
if (type.empty())
{
if (ends_with(file, ".onnx"))
type = "onnx";
else
type = "tf";
}
if (type == "onnx")
p = parse_onnx(file);
else if (type == "tf")
p = parse_tf(file, is_nhwc);
return p;
}
};
struct read : command<read>
{
loader l;
void parse(argument_parser& ap)
{
l.parse(ap);
}
void run()
{
auto p = l.load();
std::cout << p << std::endl;
}
};
struct main_command struct main_command
{ {
static std::string get_command_help() static std::string get_command_help()
{ {
std::string result = "Commands:\n"; std::string result = "Commands:\n";
for(const auto& p : migraphx::driver::get_commands()) for(const auto& p : get_commands())
result += " " + p.first + "\n"; result += " " + p.first + "\n";
return result; return result;
} }
void parse(migraphx::driver::argument_parser& ap) void parse(argument_parser& ap)
{ {
ap.add(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help())); ap.add(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help()));
} }
...@@ -18,12 +74,17 @@ struct main_command ...@@ -18,12 +74,17 @@ struct main_command
void run() {} void run() {}
}; };
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
using namespace migraphx::driver;
int main(int argc, const char* argv[]) int main(int argc, const char* argv[])
{ {
std::vector<std::string> args(argv + 1, argv + argc); std::vector<std::string> args(argv + 1, argv + argc);
if(args.empty()) if(args.empty())
return 0; return 0;
auto&& m = migraphx::driver::get_commands(); auto&& m = get_commands();
auto cmd = args.front(); auto cmd = args.front();
if(m.count(cmd) > 0) if(m.count(cmd) > 0)
{ {
...@@ -31,11 +92,7 @@ int main(int argc, const char* argv[]) ...@@ -31,11 +92,7 @@ int main(int argc, const char* argv[])
} }
else else
{ {
migraphx::driver::argument_parser ap; run_command<main_command>(args);
main_command mc;
mc.parse(ap);
ap.parse(args);
mc.run();
} }
return 0; return 0;
} }
...@@ -15,35 +15,17 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT ...@@ -15,35 +15,17 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT
template <bool B> template <bool B>
using bool_c = std::integral_constant<bool, B>; using bool_c = std::integral_constant<bool, B>;
template <int N> #define MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y) x##y
struct requires_enum #define MIGRAPHX_REQUIRES_CAT(x, y) MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y)
{
enum e
{
a = 0
};
};
#define MIGRAPHX_REQUIRES_CAT(x, y) x##y #define MIGRAPHX_REQUIRES_VAR() MIGRAPHX_REQUIRES_CAT(PrivateRequires, __LINE__)
#ifdef CPPCHECK #ifdef CPPCHECK
#define MIGRAPHX_REQUIRES(...) class = void #define MIGRAPHX_REQUIRES(...) class = void
#else #else
#if 0
// TODO: This currently crashed on clang
#define MIGRAPHX_REQUIRES(...) \ #define MIGRAPHX_REQUIRES(...) \
typename migraphx::requires_enum<__LINE__>::e MIGRAPHX_REQUIRES_CAT( \ bool MIGRAPHX_REQUIRES_VAR()=true, \
PrivateRequires, \ typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() && (migraphx::and_<__VA_ARGS__>{})), int>::type = 0
__LINE__) = migraphx::requires_enum<__LINE__>::a, \
class = typename std::enable_if<and_<__VA_ARGS__, \
MIGRAPHX_REQUIRES_CAT(PrivateRequires, __LINE__) == \
migraphx::requires_enum<__LINE__>::a>{}>::type
#else
#define MIGRAPHX_REQUIRES(...) \
typename migraphx::requires_enum<__LINE__>::e MIGRAPHX_REQUIRES_CAT( \
PrivateRequires, __LINE__) = migraphx::requires_enum<__LINE__>::a, \
class = typename std::enable_if<migraphx::and_<__VA_ARGS__>{}>::type
#endif
#endif #endif
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment