Commit f6216e40 authored by Paul's avatar Paul
Browse files

Add validations and required checking

parent a150ce8f
...@@ -40,6 +40,7 @@ ...@@ -40,6 +40,7 @@
#include <migraphx/requires.hpp> #include <migraphx/requires.hpp>
#include <migraphx/type_name.hpp> #include <migraphx/type_name.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/algorithm.hpp> #include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
...@@ -181,13 +182,17 @@ struct argument_parser ...@@ -181,13 +182,17 @@ struct argument_parser
{ {
struct argument struct argument
{ {
using action_function = std::function<bool(argument_parser&, const std::vector<std::string>&)>;
using validate_function = std::function<void(const argument_parser&, const std::vector<std::string>&)>;
std::vector<std::string> flags; std::vector<std::string> flags;
std::function<bool(argument_parser&, const std::vector<std::string>&)> action{}; action_function action{};
std::string type = ""; std::string type = "";
std::string help = ""; std::string help = "";
std::string metavar = ""; std::string metavar = "";
std::string default_value = ""; std::string default_value = "";
unsigned nargs = 1; unsigned nargs = 1;
bool required = true;
std::vector<validate_function> validations{};
}; };
template <class T, MIGRAPHX_REQUIRES(is_multi_value<T>{})> template <class T, MIGRAPHX_REQUIRES(is_multi_value<T>{})>
...@@ -220,6 +225,8 @@ struct argument_parser ...@@ -220,6 +225,8 @@ struct argument_parser
arguments.push_back({flags, [&](auto&&, const std::vector<std::string>& params) { arguments.push_back({flags, [&](auto&&, const std::vector<std::string>& params) {
if(params.empty()) if(params.empty())
throw std::runtime_error("Flag with no value."); throw std::runtime_error("Flag with no value.");
if (not is_multi_value<T>{} and params.size() > 1)
throw std::runtime_error("Too many arguments passed.");
x = value_parser<T>::apply(params.back()); x = value_parser<T>::apply(params.back());
return false; return false;
}}); }});
...@@ -247,6 +254,11 @@ struct argument_parser ...@@ -247,6 +254,11 @@ struct argument_parser
return [=](auto&&, auto& arg) { arg.nargs = n; }; return [=](auto&&, auto& arg) { arg.nargs = n; };
} }
MIGRAPHX_DRIVER_STATIC auto required()
{
return [=](auto&&, auto& arg) { arg.required = true; };
}
template <class F> template <class F>
MIGRAPHX_DRIVER_STATIC auto write_action(F f) MIGRAPHX_DRIVER_STATIC auto write_action(F f)
{ {
...@@ -281,6 +293,26 @@ struct argument_parser ...@@ -281,6 +293,26 @@ struct argument_parser
}); });
} }
template <class F>
MIGRAPHX_DRIVER_STATIC auto validate(F f)
{
return [=](const auto& x, auto& arg) {
arg.validations.push_back([&, f](auto& self, const std::vector<std::string>& params) {
f(self, x, params);
});
};
}
MIGRAPHX_DRIVER_STATIC auto file_exist()
{
return validate([](auto&, auto&, auto& params) {
if (params.empty())
throw std::runtime_error("No argument passed.");
if (not fs::exists(params.back()))
throw std::runtime_error("Path does not exists: " + params.back());
});
}
template <class F> template <class F>
argument* find_argument(F f) argument* find_argument(F f)
{ {
...@@ -416,9 +448,16 @@ struct argument_parser ...@@ -416,9 +448,16 @@ struct argument_parser
{ {
std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl; std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
std::cout << " " << exe_name << " "; std::cout << " " << exe_name << " ";
std::cout << flag; if (flag.empty())
if(not arg.type.empty()) {
std::cout << " [" << arg.type << "]"; std::cout << arg.metavar;
}
else
{
std::cout << flag;
if(not arg.type.empty())
std::cout << " [" << arg.type << "]";
}
std::cout << std::endl; std::cout << std::endl;
} }
...@@ -462,6 +501,8 @@ struct argument_parser ...@@ -462,6 +501,8 @@ struct argument_parser
std::string msg = ""; std::string msg = "";
try try
{ {
for(const auto& v:arg.validations)
v(*this, inputs);
return arg.action(*this, inputs); return arg.action(*this, inputs);
} }
catch(const std::exception& e) catch(const std::exception& e)
...@@ -487,8 +528,9 @@ struct argument_parser ...@@ -487,8 +528,9 @@ struct argument_parser
} }
else else
{ {
const auto& flag_name = flag.empty() ? arg.metavar : flag;
std::cout << "Invalid input to '" << color::fg_yellow; std::cout << "Invalid input to '" << color::fg_yellow;
std::cout << flag; std::cout << flag_name;
if(not arg.type.empty()) if(not arg.type.empty())
std::cout << " [" << arg.type << "]"; std::cout << " [" << arg.type << "]";
std::cout << color::reset << "'" << std::endl; std::cout << color::reset << "'" << std::endl;
...@@ -518,6 +560,7 @@ struct argument_parser ...@@ -518,6 +560,7 @@ struct argument_parser
generic_parse(std::move(args), [&](const std::string& x) { return keywords[x]; }); generic_parse(std::move(args), [&](const std::string& x) { return keywords[x]; });
for(auto&& arg : arguments) for(auto&& arg : arguments)
{ {
bool used = false;
auto flags = arg.flags; auto flags = arg.flags;
if(flags.empty()) if(flags.empty())
flags = {""}; flags = {""};
...@@ -527,6 +570,7 @@ struct argument_parser ...@@ -527,6 +570,7 @@ struct argument_parser
{ {
if(run_action(arg, flag, arg_map[flag])) if(run_action(arg, flag, arg_map[flag]))
return true; return true;
used = true;
} }
} }
} }
......
...@@ -73,7 +73,7 @@ struct loader ...@@ -73,7 +73,7 @@ struct loader
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
ap(file, {}, ap.metavar("<input file>")); ap(file, {}, ap.metavar("<input file>"), ap.file_exist());
ap(model, {"--model"}, ap.help("Load model"), ap.type("resnet50|inceptionv3|alexnet")); ap(model, {"--model"}, ap.help("Load model"), ap.type("resnet50|inceptionv3|alexnet"));
ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx")); ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf")); ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment