Unverified Commit ebdddf58 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Improve help and error reporting in driver (#1258)

* Improve type printing in driver
* Improve error with incorrect order for command
* Add spell checking of arguments
* Add validations and required checking
* Add required arguments and groups
parent e2106d08
...@@ -27,11 +27,13 @@ ...@@ -27,11 +27,13 @@
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <list>
#include <set> #include <set>
#include <string> #include <string>
#include <sstream> #include <sstream>
#include <type_traits> #include <type_traits>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -39,9 +41,16 @@ ...@@ -39,9 +41,16 @@
#include <migraphx/requires.hpp> #include <migraphx/requires.hpp>
#include <migraphx/type_name.hpp> #include <migraphx/type_name.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#ifndef _WIN32
#include <unistd.h>
#endif
namespace migraphx { namespace migraphx {
namespace driver { namespace driver {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -74,6 +83,65 @@ template <class T> ...@@ -74,6 +83,65 @@ template <class T>
using is_multi_value = using is_multi_value =
std::integral_constant<bool, (is_container<T>{} and not std::is_convertible<T, std::string>{})>; std::integral_constant<bool, (is_container<T>{} and not std::is_convertible<T, std::string>{})>;
enum class color
{
reset = 0,
bold = 1,
underlined = 4,
fg_red = 31,
fg_green = 32,
fg_yellow = 33,
fg_blue = 34,
fg_default = 39,
bg_red = 41,
bg_green = 42,
bg_yellow = 43,
bg_blue = 44,
bg_default = 49
};
inline std::ostream& operator<<(std::ostream& os, const color& c)
{
#ifndef _WIN32
static const bool use_color = isatty(STDOUT_FILENO) != 0;
if(use_color)
return os << "\033[" << static_cast<std::size_t>(c) << "m";
#endif
return os;
}
inline std::string colorize(color c, const std::string& s)
{
std::stringstream ss;
ss << c << s << color::reset;
return ss.str();
}
template <class T>
struct type_name
{
static const std::string& apply() { return migraphx::get_type_name<T>(); }
};
template <>
struct type_name<std::string>
{
static const std::string& apply()
{
static const std::string name = "std::string";
return name;
}
};
template <class T>
struct type_name<std::vector<T>>
{
static const std::string& apply()
{
static const std::string name = "std::vector<" + type_name<T>::apply() + ">";
return name;
}
};
template <class T> template <class T>
struct value_parser struct value_parser
{ {
...@@ -85,7 +153,7 @@ struct value_parser ...@@ -85,7 +153,7 @@ struct value_parser
ss.str(x); ss.str(x);
ss >> result; ss >> result;
if(ss.fail()) if(ss.fail())
throw std::runtime_error("Failed to parse: " + x); throw std::runtime_error("Failed to parse '" + x + "' as " + type_name<T>::apply());
return result; return result;
} }
...@@ -97,7 +165,7 @@ struct value_parser ...@@ -97,7 +165,7 @@ struct value_parser
ss.str(x); ss.str(x);
ss >> i; ss >> i;
if(ss.fail()) if(ss.fail())
throw std::runtime_error("Failed to parse: " + x); throw std::runtime_error("Failed to parse '" + x + "' as " + type_name<T>::apply());
return static_cast<T>(i); return static_cast<T>(i);
} }
...@@ -115,13 +183,42 @@ struct argument_parser ...@@ -115,13 +183,42 @@ struct argument_parser
{ {
struct argument struct argument
{ {
using action_function =
std::function<bool(argument_parser&, const std::vector<std::string>&)>;
using validate_function =
std::function<void(const argument_parser&, const std::vector<std::string>&)>;
std::vector<std::string> flags; std::vector<std::string> flags;
std::function<bool(argument_parser&, const std::vector<std::string>&)> action{}; action_function action{};
std::string type = ""; std::string type = "";
std::string help = ""; std::string help = "";
std::string metavar = ""; std::string metavar = "";
std::string default_value = ""; std::string default_value = "";
std::string group = "";
unsigned nargs = 1; unsigned nargs = 1;
bool required = false;
std::vector<validate_function> validations{};
std::string usage(const std::string& flag) const
{
std::stringstream ss;
if(flag.empty())
{
ss << metavar;
}
else
{
ss << flag;
if(not type.empty())
ss << " [" << type << "]";
}
return ss.str();
}
std::string usage() const
{
if(flags.empty())
return usage("");
return usage(flags.front());
}
}; };
template <class T, MIGRAPHX_REQUIRES(is_multi_value<T>{})> template <class T, MIGRAPHX_REQUIRES(is_multi_value<T>{})>
...@@ -154,12 +251,14 @@ struct argument_parser ...@@ -154,12 +251,14 @@ struct argument_parser
arguments.push_back({flags, [&](auto&&, const std::vector<std::string>& params) { arguments.push_back({flags, [&](auto&&, const std::vector<std::string>& params) {
if(params.empty()) if(params.empty())
throw std::runtime_error("Flag with no value."); throw std::runtime_error("Flag with no value.");
if(not is_multi_value<T>{} and params.size() > 1)
throw std::runtime_error("Too many arguments passed.");
x = value_parser<T>::apply(params.back()); x = value_parser<T>::apply(params.back());
return false; return false;
}}); }});
argument& arg = arguments.back(); argument& arg = arguments.back();
arg.type = migraphx::get_type_name<T>(); arg.type = type_name<T>::apply();
migraphx::each_args([&](auto f) { f(x, arg); }, fs...); migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
if(not arg.default_value.empty() and arg.nargs > 0) if(not arg.default_value.empty() and arg.nargs > 0)
arg.default_value = as_string_value(x); arg.default_value = as_string_value(x);
...@@ -181,6 +280,11 @@ struct argument_parser ...@@ -181,6 +280,11 @@ struct argument_parser
return [=](auto&&, auto& arg) { arg.nargs = n; }; return [=](auto&&, auto& arg) { arg.nargs = n; };
} }
MIGRAPHX_DRIVER_STATIC auto required()
{
return [=](auto&&, auto& arg) { arg.required = true; };
}
template <class F> template <class F>
MIGRAPHX_DRIVER_STATIC auto write_action(F f) MIGRAPHX_DRIVER_STATIC auto write_action(F f)
{ {
...@@ -215,13 +319,141 @@ struct argument_parser ...@@ -215,13 +319,141 @@ struct argument_parser
}); });
} }
MIGRAPHX_DRIVER_STATIC auto show_help(const std::string& msg = "") template <class F>
MIGRAPHX_DRIVER_STATIC auto validate(F f)
{
return [=](const auto& x, auto& arg) {
arg.validations.push_back(
[&, f](auto& self, const std::vector<std::string>& params) { f(self, x, params); });
};
}
MIGRAPHX_DRIVER_STATIC auto file_exist()
{
return validate([](auto&, auto&, auto& params) {
if(params.empty())
throw std::runtime_error("No argument passed.");
if(not fs::exists(params.back()))
throw std::runtime_error("Path does not exists: " + params.back());
});
}
template <class F>
argument* find_argument(F f)
{
auto it = std::find_if(arguments.begin(), arguments.end(), f);
if(it == arguments.end())
return nullptr;
return std::addressof(*it);
}
template <class F>
bool has_argument(F f)
{
return find_argument(f) != nullptr;
}
template <class F>
std::vector<argument*> find_arguments(F f)
{
std::vector<argument*> result;
for(auto& arg : arguments)
{
if(not f(arg))
continue;
result.push_back(&arg);
}
return result;
}
std::vector<argument*> get_group_arguments(const std::string& group)
{
return find_arguments([&](const auto& arg) { return arg.group == group; });
}
std::vector<argument*> get_required_arguments()
{
return find_arguments([&](const auto& arg) { return arg.required; });
}
template <class SequenceContainer>
std::vector<std::string> get_argument_usages(SequenceContainer args)
{
std::vector<std::string> usage_flags;
std::unordered_set<std::string> found_groups;
// Remove arguments that belong to a group
auto it = std::remove_if(args.begin(), args.end(), [&](const argument* arg) {
if(arg->group.empty())
return false;
found_groups.insert(arg->group);
return true;
});
args.erase(it, args.end());
transform(found_groups, std::back_inserter(usage_flags), [&](auto&& group) {
std::vector<std::string> either_flags;
transform(get_group_arguments(group), std::back_inserter(either_flags), [](auto* arg) {
return arg->usage();
});
return "(" + join_strings(either_flags, "|") + ")";
});
transform(args, std::back_inserter(usage_flags), [&](auto* arg) { return arg->usage(); });
return usage_flags;
}
auto show_help(const std::string& msg = "")
{ {
return do_action([=](auto& self) { return do_action([=](auto& self) {
argument* input_argument =
self.find_argument([](const auto& arg) { return arg.flags.empty(); });
auto required_usages = get_argument_usages(get_required_arguments());
if(required_usages.empty() && input_argument)
required_usages.push_back(input_argument->metavar);
required_usages.insert(required_usages.begin(), "<options>");
print_usage(required_usages);
std::cout << std::endl;
if(self.find_argument([](const auto& arg) { return arg.nargs == 0; }))
{
std::cout << color::fg_yellow << "FLAGS:" << color::reset << std::endl;
std::cout << std::endl;
for(auto&& arg : self.arguments) for(auto&& arg : self.arguments)
{
if(arg.nargs != 0)
continue;
const int col_align = 35;
std::string prefix = " ";
int len = 0;
std::cout << color::fg_green;
for(const std::string& a : arg.flags)
{
len += prefix.length() + a.length();
std::cout << prefix;
std::cout << a;
prefix = ", ";
}
std::cout << color::reset;
int spaces = col_align - len;
if(spaces < 0)
{ {
std::cout << std::endl; std::cout << std::endl;
}
else
{
for(int i = 0; i < spaces; i++)
std::cout << " ";
}
std::cout << arg.help << std::endl;
}
std::cout << std::endl;
}
if(self.find_argument([](const auto& arg) { return arg.nargs != 0; }))
{
std::cout << color::fg_yellow << "OPTIONS:" << color::reset << std::endl;
for(auto&& arg : self.arguments)
{
if(arg.nargs == 0)
continue;
std::cout << std::endl;
std::string prefix = " "; std::string prefix = " ";
std::cout << color::fg_green;
if(arg.flags.empty()) if(arg.flags.empty())
{ {
std::cout << prefix; std::cout << prefix;
...@@ -233,9 +465,10 @@ struct argument_parser ...@@ -233,9 +465,10 @@ struct argument_parser
std::cout << a; std::cout << a;
prefix = ", "; prefix = ", ";
} }
std::cout << color::reset;
if(not arg.type.empty()) if(not arg.type.empty())
{ {
std::cout << " [" << arg.type << "]"; std::cout << " [" << color::fg_blue << arg.type << color::reset << "]";
if(not arg.default_value.empty()) if(not arg.default_value.empty())
std::cout << " (Default: " << arg.default_value << ")"; std::cout << " (Default: " << arg.default_value << ")";
} }
...@@ -243,6 +476,7 @@ struct argument_parser ...@@ -243,6 +476,7 @@ struct argument_parser
std::cout << " " << arg.help << std::endl; std::cout << " " << arg.help << std::endl;
} }
std::cout << std::endl; std::cout << std::endl;
}
if(not msg.empty()) if(not msg.empty())
std::cout << msg << std::endl; std::cout << msg << std::endl;
}); });
...@@ -263,6 +497,11 @@ struct argument_parser ...@@ -263,6 +497,11 @@ struct argument_parser
return [=](auto&, auto& arg) { arg.type = type; }; return [=](auto&, auto& arg) { arg.type = type; };
} }
MIGRAPHX_DRIVER_STATIC auto group(const std::string& group)
{
return [=](auto&, auto& arg) { arg.group = group; };
}
template <class T> template <class T>
MIGRAPHX_DRIVER_STATIC auto set_value(T value) MIGRAPHX_DRIVER_STATIC auto set_value(T value)
{ {
...@@ -276,6 +515,109 @@ struct argument_parser ...@@ -276,6 +515,109 @@ struct argument_parser
}; };
} }
template <class T>
void set_exe_name_to(T& x)
{
actions.push_back([&](const auto& self) { x = self.exe_name; });
}
void print_try_help()
{
if(has_argument([](const auto& a) { return contains(a.flags, "--help"); }))
{
std::cout << std::endl;
std::cout << "For more information try '" << color::fg_green << "--help" << color::reset
<< "'" << std::endl;
}
}
void print_usage(const std::vector<std::string>& flags) const
{
std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
std::cout << " " << exe_name << " ";
std::cout << join_strings(flags, " ") << std::endl;
}
auto spellcheck(const std::vector<std::string>& inputs)
{
struct result_t
{
const argument* arg = nullptr;
std::string correct = "";
std::string incorrect = "";
std::ptrdiff_t distance = std::numeric_limits<std::ptrdiff_t>::max();
};
result_t result;
for(const auto& input : inputs)
{
if(input.empty())
continue;
if(input[0] != '-')
continue;
for(const auto& arg : arguments)
{
for(const auto& flag : arg.flags)
{
if(flag.empty())
continue;
if(flag[0] != '-')
continue;
auto d =
levenshtein_distance(flag.begin(), flag.end(), input.begin(), input.end());
if(d < result.distance)
result = result_t{&arg, flag, input, d};
}
}
}
return result;
}
bool
run_action(const argument& arg, const std::string& flag, const std::vector<std::string>& inputs)
{
std::string msg = "";
try
{
for(const auto& v : arg.validations)
v(*this, inputs);
return arg.action(*this, inputs);
}
catch(const std::exception& e)
{
msg = e.what();
}
catch(...)
{
msg = "unknown exception";
}
std::cout << color::fg_red << color::bold << "error: " << color::reset;
auto sc = spellcheck(inputs);
if(sc.distance < 5)
{
std::cout << "Found argument '" << color::fg_yellow << sc.incorrect << color::reset
<< "'"
<< " which wasn't expected, or isn't valid in this context" << std::endl;
std::cout << " "
<< "Did you mean " << color::fg_green << sc.correct << color::reset << "?"
<< std::endl;
std::cout << std::endl;
print_usage({sc.arg->usage(sc.correct)});
}
else
{
const auto& flag_name = flag.empty() ? arg.metavar : flag;
std::cout << "Invalid input to '" << color::fg_yellow;
std::cout << arg.usage(flag_name);
std::cout << color::reset << "'" << std::endl;
std::cout << " " << msg << std::endl;
std::cout << std::endl;
print_usage({arg.usage()});
}
std::cout << std::endl;
print_try_help();
return true;
}
bool parse(std::vector<std::string> args) bool parse(std::vector<std::string> args)
{ {
std::unordered_map<std::string, unsigned> keywords; std::unordered_map<std::string, unsigned> keywords;
...@@ -286,8 +628,11 @@ struct argument_parser ...@@ -286,8 +628,11 @@ struct argument_parser
} }
auto arg_map = auto arg_map =
generic_parse(std::move(args), [&](const std::string& x) { return keywords[x]; }); generic_parse(std::move(args), [&](const std::string& x) { return keywords[x]; });
std::list<const argument*> missing_arguments;
std::unordered_set<std::string> groups_used;
for(auto&& arg : arguments) for(auto&& arg : arguments)
{ {
bool used = false;
auto flags = arg.flags; auto flags = arg.flags;
if(flags.empty()) if(flags.empty())
flags = {""}; flags = {""};
...@@ -295,14 +640,41 @@ struct argument_parser ...@@ -295,14 +640,41 @@ struct argument_parser
{ {
if(arg_map.count(flag) > 0) if(arg_map.count(flag) > 0)
{ {
if(arg.action(*this, arg_map[flag])) if(run_action(arg, flag, arg_map[flag]))
return true; return true;
used = true;
} }
} }
if(used and not arg.group.empty())
groups_used.insert(arg.group);
if(arg.required and not used)
missing_arguments.push_back(&arg);
} }
// Remove arguments from a group that is being used
missing_arguments.remove_if(
[&](const argument* arg) { return groups_used.count(arg->group); });
if(not missing_arguments.empty())
{
std::cout << color::fg_red << color::bold << "error: " << color::reset;
std::cout << "The following required arguments were not provided:" << std::endl;
std::cout << " " << color::fg_red
<< join_strings(get_argument_usages(std::move(missing_arguments)), " ")
<< color::reset << std::endl;
std::cout << std::endl;
auto required_usages = get_argument_usages(get_required_arguments());
print_usage(required_usages);
print_try_help();
return true;
}
for(auto&& action : actions)
action(*this);
return false; return false;
} }
void set_exe_name(const std::string& s) { exe_name = s; }
const std::string& get_exe_name() const { return exe_name; }
using string_map = std::unordered_map<std::string, std::vector<std::string>>; using string_map = std::unordered_map<std::string, std::vector<std::string>>;
template <class IsKeyword> template <class IsKeyword>
static string_map generic_parse(std::vector<std::string> as, IsKeyword is_keyword) static string_map generic_parse(std::vector<std::string> as, IsKeyword is_keyword)
...@@ -337,7 +709,9 @@ struct argument_parser ...@@ -337,7 +709,9 @@ struct argument_parser
} }
private: private:
std::vector<argument> arguments; std::list<argument> arguments;
std::string exe_name = "";
std::vector<std::function<void(argument_parser&)>> actions;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -41,7 +41,10 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -41,7 +41,10 @@ inline namespace MIGRAPHX_INLINE_NS {
inline auto& get_commands() inline auto& get_commands()
{ {
// NOLINTNEXTLINE // NOLINTNEXTLINE
static std::unordered_map<std::string, std::function<void(std::vector<std::string> args)>> m; static std::unordered_map<
std::string,
std::function<void(const std::string& exe_name, std::vector<std::string> args)>>
m;
return m; return m;
} }
...@@ -65,10 +68,11 @@ const std::string& command_name() ...@@ -65,10 +68,11 @@ const std::string& command_name()
} }
template <class T> template <class T>
void run_command(std::vector<std::string> args, bool add_help = false) void run_command(const std::string& exe_name, std::vector<std::string> args, bool add_help = false)
{ {
T x; T x;
argument_parser ap; argument_parser ap;
ap.set_exe_name(exe_name + " " + command_name<T>());
if(add_help) if(add_help)
ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help()); ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help());
x.parse(ap); x.parse(ap);
...@@ -81,7 +85,9 @@ template <class T> ...@@ -81,7 +85,9 @@ template <class T>
int auto_register_command() int auto_register_command()
{ {
auto& m = get_commands(); auto& m = get_commands();
m[command_name<T>()] = [](std::vector<std::string> args) { run_command<T>(args, true); }; m[command_name<T>()] = [](const std::string& exe_name, std::vector<std::string> args) {
run_command<T>(exe_name, args, true);
};
return 0; return 0;
} }
......
...@@ -73,8 +73,12 @@ struct loader ...@@ -73,8 +73,12 @@ struct loader
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
ap(file, {}, ap.metavar("<input file>")); ap(file, {}, ap.metavar("<input file>"), ap.file_exist(), ap.required(), ap.group("input"));
ap(model, {"--model"}, ap.help("Load model"), ap.type("resnet50|inceptionv3|alexnet")); ap(model,
{"--model"},
ap.help("Load model"),
ap.type("resnet50|inceptionv3|alexnet"),
ap.group("input"));
ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx")); ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf")); ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
ap(file_type, {"--migraphx"}, ap.help("Load as MIGraphX"), ap.set_value("migraphx")); ap(file_type, {"--migraphx"}, ap.help("Load as MIGraphX"), ap.set_value("migraphx"));
...@@ -578,26 +582,62 @@ struct onnx : command<onnx> ...@@ -578,26 +582,62 @@ struct onnx : command<onnx>
struct main_command struct main_command
{ {
static std::string get_command_help() static std::string get_command_help(const std::string& title = colorize(color::fg_yellow,
"COMMANDS:"))
{ {
std::string result = "Commands:\n"; std::string result = title + "\n";
return std::accumulate(get_commands().begin(), std::vector<std::string> commands(get_commands().size());
std::transform(get_commands().begin(),
get_commands().end(), get_commands().end(),
result, commands.begin(),
[](auto r, auto&& p) { return r + " " + p.first + "\n"; }); [](const auto& p) { return colorize(color::fg_green, p.first); });
std::sort(commands.begin(), commands.end());
return std::accumulate(commands.begin(), commands.end(), result, [](auto r, auto&& s) {
return r + " " + s + "\n";
});
} }
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
std::string version_str = "MIGraphX Version: " + std::to_string(MIGRAPHX_VERSION_MAJOR) + std::string version_str = "MIGraphX Version: " + std::to_string(MIGRAPHX_VERSION_MAJOR) +
"." + std::to_string(MIGRAPHX_VERSION_MINOR); "." + std::to_string(MIGRAPHX_VERSION_MINOR);
ap(wrong_commands, {}, ap.metavar("<command>"), ap.append());
ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help())); ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help()));
ap(nullptr, ap(nullptr,
{"-v", "--version"}, {"-v", "--version"},
ap.help("Show MIGraphX version"), ap.help("Show MIGraphX version"),
ap.show_help(version_str)); ap.show_help(version_str));
// Trim command off of exe name
ap.set_exe_name(ap.get_exe_name().substr(0, ap.get_exe_name().size() - 5));
ap.set_exe_name_to(exe_name);
} }
void run() {} std::vector<std::string> wrong_commands{};
std::string exe_name = "<exe>";
void run()
{
std::cout << color::fg_red << color::bold << "error: " << color::reset;
auto it = std::find_if(wrong_commands.begin(), wrong_commands.end(), [](const auto& c) {
return get_commands().count(c) > 0;
});
if(it == wrong_commands.end())
{
std::cout << "'" << color::fg_yellow << wrong_commands.front() << color::reset
<< "' is not a valid command." << std::endl;
std::cout << get_command_help("Available commands:") << std::endl;
}
else
{
std::cout << "command '" << color::fg_yellow << *it << color::reset
<< "' must be first argument" << std::endl;
std::cout << std::endl;
std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
std::cout << " " << exe_name << " " << *it << " <options>" << std::endl;
}
std::cout << std::endl;
}
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
...@@ -619,11 +659,11 @@ int main(int argc, const char* argv[]) ...@@ -619,11 +659,11 @@ int main(int argc, const char* argv[])
auto cmd = args.front(); auto cmd = args.front();
if(m.count(cmd) > 0) if(m.count(cmd) > 0)
{ {
m.at(cmd)({args.begin() + 1, args.end()}); m.at(cmd)(argv[0], {args.begin() + 1, args.end()});
} }
else else
{ {
run_command<main_command>(args); run_command<main_command>(argv[0], args);
} }
return 0; return 0;
......
...@@ -74,6 +74,22 @@ void group_unique(Iterator start, Iterator last, Output out, Predicate pred) ...@@ -74,6 +74,22 @@ void group_unique(Iterator start, Iterator last, Output out, Predicate pred)
} }
} }
template <class Iterator1, class Iterator2>
std::ptrdiff_t
levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterator2 last2)
{
if(first1 == last1)
return std::distance(first2, last2);
if(first2 == last2)
return std::distance(first1, last1);
if(*first1 == *first2)
return levenshtein_distance(std::next(first1), last1, std::next(first2), last2);
auto x1 = levenshtein_distance(std::next(first1), last1, std::next(first2), last2);
auto x2 = levenshtein_distance(first1, last1, std::next(first2), last2);
auto x3 = levenshtein_distance(std::next(first1), last1, first2, last2);
return std::ptrdiff_t{1} + std::min({x1, x2, x3});
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment