/* * The MIT License (MIT) * * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ #ifndef MIGRAPHX_GUARD_RTGLIB_ARGUMENT_PARSER_HPP #define MIGRAPHX_GUARD_RTGLIB_ARGUMENT_PARSER_HPP #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #ifndef _WIN32 #include #endif namespace migraphx { namespace driver { inline namespace MIGRAPHX_INLINE_NS { #ifdef MIGRAPHX_USE_CLANG_TIDY #define MIGRAPHX_DRIVER_STATIC #else #define MIGRAPHX_DRIVER_STATIC static #endif template using bare = std::remove_cv_t>; namespace detail { template auto is_container(int, T&& x) -> decltype(x.insert(x.end(), *x.begin()), std::true_type{}); template std::false_type is_container(float, T&&); } // namespace detail template struct is_container : decltype(detail::is_container(int(0), std::declval())) { }; template using is_multi_value = std::integral_constant{} and not std::is_convertible{})>; enum class color { reset = 0, bold = 1, underlined = 4, fg_red = 31, fg_green = 32, fg_yellow = 33, fg_blue = 34, fg_default = 39, bg_red = 41, bg_green = 42, bg_yellow = 43, bg_blue = 44, bg_default = 49 }; inline std::ostream& operator<<(std::ostream& os, const color& c) { #ifndef _WIN32 static const bool use_color = isatty(STDOUT_FILENO) != 0; if(use_color) return os << "\033[" << static_cast(c) << "m"; #endif return os; } inline std::string colorize(color c, const std::string& s) { std::stringstream ss; ss << c << s << color::reset; return ss.str(); } template struct type_name { static const std::string& apply() { return migraphx::get_type_name(); } }; template <> struct type_name { static const std::string& apply() { static const std::string name = "std::string"; return name; } }; template struct type_name> { static const std::string& apply() { static const std::string name = "std::vector<" + type_name::apply() + ">"; return name; } }; template struct value_parser { template {} and not is_multi_value{})> static T apply(const std::string& x) { T result; std::stringstream ss; ss.str(x); ss >> result; if(ss.fail()) throw std::runtime_error("Failed to parse '" + x + "' as " + type_name::apply()); return result; } template {} and not is_multi_value{})> static T apply(const std::string& x) { std::ptrdiff_t i; std::stringstream ss; ss.str(x); ss >> i; if(ss.fail()) throw std::runtime_error("Failed to parse '" + x + "' as " + type_name::apply()); return static_cast(i); } template {} and not std::is_enum{})> static T apply(const std::string& x) { T result; using value_type = typename T::value_type; result.insert(result.end(), value_parser::apply(x)); return result; } }; struct argument_parser { struct argument { using action_function = std::function&)>; using validate_function = std::function&)>; std::vector flags; action_function action{}; std::string type = ""; std::string help = ""; std::string metavar = ""; std::string default_value = ""; unsigned nargs = 1; bool required = true; std::vector validations{}; }; template {})> std::string as_string_value(const T& x) { return to_string_range(x); } template auto as_string_value(rank<1>, const T& x) -> decltype(to_string(x)) { return to_string(x); } template std::string as_string_value(rank<0>, const T&) { throw std::runtime_error("Can't convert to string"); } template {})> std::string as_string_value(const T& x) { return as_string_value(rank<1>{}, x); } template void operator()(T& x, const std::vector& flags, Fs... fs) { arguments.push_back({flags, [&](auto&&, const std::vector& params) { if(params.empty()) throw std::runtime_error("Flag with no value."); if(not is_multi_value{} and params.size() > 1) throw std::runtime_error("Too many arguments passed."); x = value_parser::apply(params.back()); return false; }}); argument& arg = arguments.back(); arg.type = type_name::apply(); migraphx::each_args([&](auto f) { f(x, arg); }, fs...); if(not arg.default_value.empty() and arg.nargs > 0) arg.default_value = as_string_value(x); } template void operator()(std::nullptr_t x, std::vector flags, Fs... fs) { arguments.push_back({std::move(flags)}); argument& arg = arguments.back(); arg.type = ""; arg.nargs = 0; migraphx::each_args([&](auto f) { f(x, arg); }, fs...); } MIGRAPHX_DRIVER_STATIC auto nargs(unsigned n = 1) { return [=](auto&&, auto& arg) { arg.nargs = n; }; } MIGRAPHX_DRIVER_STATIC auto required() { return [=](auto&&, auto& arg) { arg.required = true; }; } template MIGRAPHX_DRIVER_STATIC auto write_action(F f) { return [=](auto& x, auto& arg) { arg.action = [&, f](auto& self, const std::vector& params) { f(self, x, params); return false; }; }; } template MIGRAPHX_DRIVER_STATIC auto do_action(F f) { return [=](auto&, auto& arg) { arg.nargs = 0; arg.action = [&, f](auto& self, const std::vector&) { f(self); return true; }; }; } MIGRAPHX_DRIVER_STATIC auto append() { return write_action([](auto&, auto& x, auto& params) { using type = typename bare::value_type; std::transform(params.begin(), params.end(), std::inserter(x, x.end()), [](std::string y) { return value_parser::apply(y); }); }); } template MIGRAPHX_DRIVER_STATIC auto validate(F f) { return [=](const auto& x, auto& arg) { arg.validations.push_back( [&, f](auto& self, const std::vector& params) { f(self, x, params); }); }; } MIGRAPHX_DRIVER_STATIC auto file_exist() { return validate([](auto&, auto&, auto& params) { if(params.empty()) throw std::runtime_error("No argument passed."); if(not fs::exists(params.back())) throw std::runtime_error("Path does not exists: " + params.back()); }); } template argument* find_argument(F f) { auto it = std::find_if(arguments.begin(), arguments.end(), f); if(it == arguments.end()) return nullptr; return std::addressof(*it); } template bool has_argument(F f) { return find_argument(f) != nullptr; } MIGRAPHX_DRIVER_STATIC auto show_help(const std::string& msg = "") { return do_action([=](auto& self) { argument* input_argument = self.find_argument([](const auto& arg) { return arg.flags.empty(); }); std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl; std::cout << " " << self.exe_name << " "; if(input_argument) std::cout << input_argument->metavar; std::cout << std::endl; std::cout << std::endl; if(self.find_argument([](const auto& arg) { return arg.nargs == 0; })) { std::cout << color::fg_yellow << "FLAGS:" << color::reset << std::endl; std::cout << std::endl; for(auto&& arg : self.arguments) { if(arg.nargs != 0) continue; const int col_align = 35; std::string prefix = " "; int len = 0; std::cout << color::fg_green; for(const std::string& a : arg.flags) { len += prefix.length() + a.length(); std::cout << prefix; std::cout << a; prefix = ", "; } std::cout << color::reset; int spaces = col_align - len; if(spaces < 0) { std::cout << std::endl; } else { for(int i = 0; i < spaces; i++) std::cout << " "; } std::cout << arg.help << std::endl; } std::cout << std::endl; } if(self.find_argument([](const auto& arg) { return arg.nargs != 0; })) { std::cout << color::fg_yellow << "OPTIONS:" << color::reset << std::endl; for(auto&& arg : self.arguments) { if(arg.nargs == 0) continue; std::cout << std::endl; std::string prefix = " "; std::cout << color::fg_green; if(arg.flags.empty()) { std::cout << prefix; std::cout << arg.metavar; } for(const std::string& a : arg.flags) { std::cout << prefix; std::cout << a; prefix = ", "; } std::cout << color::reset; if(not arg.type.empty()) { std::cout << " [" << color::fg_blue << arg.type << color::reset << "]"; if(not arg.default_value.empty()) std::cout << " (Default: " << arg.default_value << ")"; } std::cout << std::endl; std::cout << " " << arg.help << std::endl; } std::cout << std::endl; } if(not msg.empty()) std::cout << msg << std::endl; }); } MIGRAPHX_DRIVER_STATIC auto help(const std::string& help) { return [=](auto&, auto& arg) { arg.help = help; }; } MIGRAPHX_DRIVER_STATIC auto metavar(const std::string& metavar) { return [=](auto&, auto& arg) { arg.metavar = metavar; }; } MIGRAPHX_DRIVER_STATIC auto type(const std::string& type) { return [=](auto&, auto& arg) { arg.type = type; }; } template MIGRAPHX_DRIVER_STATIC auto set_value(T value) { return [=](auto& x, auto& arg) { arg.nargs = 0; arg.type = ""; arg.action = [&, value](auto&, const std::vector&) { x = value; return false; }; }; } template void set_exe_name_to(T& x) { actions.push_back([&](const auto& self) { x = self.exe_name; }); } void print_usage_for(const argument& arg, const std::string& flag) const { std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl; std::cout << " " << exe_name << " "; if(flag.empty()) { std::cout << arg.metavar; } else { std::cout << flag; if(not arg.type.empty()) std::cout << " [" << arg.type << "]"; } std::cout << std::endl; } auto spellcheck(const std::vector& inputs) { struct result_t { const argument* arg = nullptr; std::string correct = ""; std::string incorrect = ""; std::ptrdiff_t distance = std::numeric_limits::max(); }; result_t result; for(const auto& input : inputs) { if(input.empty()) continue; if(input[0] != '-') continue; for(const auto& arg : arguments) { for(const auto& flag : arg.flags) { if(flag.empty()) continue; if(flag[0] != '-') continue; auto d = levenshtein_distance(flag.begin(), flag.end(), input.begin(), input.end()); if(d < result.distance) result = result_t{&arg, flag, input, d}; } } } return result; } bool run_action(const argument& arg, const std::string& flag, const std::vector& inputs) { std::string msg = ""; try { for(const auto& v : arg.validations) v(*this, inputs); return arg.action(*this, inputs); } catch(const std::exception& e) { msg = e.what(); } catch(...) { msg = "unknown exception"; } std::cout << color::fg_red << color::bold << "error: " << color::reset; auto sc = spellcheck(inputs); if(sc.distance < 5) { std::cout << "Found argument '" << color::fg_yellow << sc.incorrect << color::reset << "'" << " which wasn't expected, or isn't valid in this context" << std::endl; std::cout << " " << "Did you mean " << color::fg_green << sc.correct << color::reset << "?" << std::endl; std::cout << std::endl; print_usage_for(*sc.arg, sc.correct); } else { const auto& flag_name = flag.empty() ? arg.metavar : flag; std::cout << "Invalid input to '" << color::fg_yellow; std::cout << flag_name; if(not arg.type.empty()) std::cout << " [" << arg.type << "]"; std::cout << color::reset << "'" << std::endl; std::cout << " " << msg << std::endl; std::cout << std::endl; print_usage_for(arg, flag); } std::cout << std::endl; if(has_argument([](const auto& a) { return contains(a.flags, "--help"); })) { std::cout << std::endl; std::cout << "For more information try '" << color::fg_green << "--help" << color::reset << "'" << std::endl; } return true; } bool parse(std::vector args) { std::unordered_map keywords; for(auto&& arg : arguments) { for(auto&& flag : arg.flags) keywords[flag] = arg.nargs + 1; } auto arg_map = generic_parse(std::move(args), [&](const std::string& x) { return keywords[x]; }); for(auto&& arg : arguments) { bool used = false; auto flags = arg.flags; if(flags.empty()) flags = {""}; for(auto&& flag : flags) { if(arg_map.count(flag) > 0) { if(run_action(arg, flag, arg_map[flag])) return true; used = true; } } } for(auto&& action : actions) action(*this); return false; } void set_exe_name(const std::string& s) { exe_name = s; } const std::string& get_exe_name() const { return exe_name; } using string_map = std::unordered_map>; template static string_map generic_parse(std::vector as, IsKeyword is_keyword) { string_map result; std::string flag; bool clear = false; for(auto&& x : as) { auto k = is_keyword(x); if(k > 0) { flag = x; result[flag]; // Ensure the flag exists if(k == 1) flag = ""; else if(k == 2) clear = true; else clear = false; } else { result[flag].push_back(x); if(clear) flag = ""; clear = false; } } return result; } private: std::list arguments; std::string exe_name = ""; std::vector> actions; }; } // namespace MIGRAPHX_INLINE_NS } // namespace driver } // namespace migraphx #endif