Commit b878f78f authored by turneram's avatar turneram
Browse files

Merge remote-tracking branch 'origin/develop' into rewrite-fast-gelu

parents 3b414cc2 55cb7d3a
......@@ -41,7 +41,10 @@ inline namespace MIGRAPHX_INLINE_NS {
inline auto& get_commands()
{
// NOLINTNEXTLINE
static std::unordered_map<std::string, std::function<void(std::vector<std::string> args)>> m;
static std::unordered_map<
std::string,
std::function<void(const std::string& exe_name, std::vector<std::string> args)>>
m;
return m;
}
......@@ -65,10 +68,11 @@ const std::string& command_name()
}
template <class T>
void run_command(std::vector<std::string> args, bool add_help = false)
void run_command(const std::string& exe_name, std::vector<std::string> args, bool add_help = false)
{
T x;
argument_parser ap;
ap.set_exe_name(exe_name + " " + command_name<T>());
if(add_help)
ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help());
x.parse(ap);
......@@ -81,7 +85,9 @@ template <class T>
int auto_register_command()
{
auto& m = get_commands();
m[command_name<T>()] = [](std::vector<std::string> args) { run_command<T>(args, true); };
m[command_name<T>()] = [](const std::string& exe_name, std::vector<std::string> args) {
run_command<T>(exe_name, args, true);
};
return 0;
}
......
......@@ -73,8 +73,12 @@ struct loader
void parse(argument_parser& ap)
{
ap(file, {}, ap.metavar("<input file>"));
ap(model, {"--model"}, ap.help("Load model"), ap.type("resnet50|inceptionv3|alexnet"));
ap(file, {}, ap.metavar("<input file>"), ap.file_exist(), ap.required(), ap.group("input"));
ap(model,
{"--model"},
ap.help("Load model"),
ap.type("resnet50|inceptionv3|alexnet"),
ap.group("input"));
ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
ap(file_type, {"--migraphx"}, ap.help("Load as MIGraphX"), ap.set_value("migraphx"));
......@@ -578,26 +582,62 @@ struct onnx : command<onnx>
struct main_command
{
static std::string get_command_help()
static std::string get_command_help(const std::string& title = colorize(color::fg_yellow,
"COMMANDS:"))
{
std::string result = "Commands:\n";
return std::accumulate(get_commands().begin(),
get_commands().end(),
result,
[](auto r, auto&& p) { return r + " " + p.first + "\n"; });
std::string result = title + "\n";
std::vector<std::string> commands(get_commands().size());
std::transform(get_commands().begin(),
get_commands().end(),
commands.begin(),
[](const auto& p) { return colorize(color::fg_green, p.first); });
std::sort(commands.begin(), commands.end());
return std::accumulate(commands.begin(), commands.end(), result, [](auto r, auto&& s) {
return r + " " + s + "\n";
});
}
void parse(argument_parser& ap)
{
std::string version_str = "MIGraphX Version: " + std::to_string(MIGRAPHX_VERSION_MAJOR) +
"." + std::to_string(MIGRAPHX_VERSION_MINOR);
ap(wrong_commands, {}, ap.metavar("<command>"), ap.append());
ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help()));
ap(nullptr,
{"-v", "--version"},
ap.help("Show MIGraphX version"),
ap.show_help(version_str));
// Trim command off of exe name
ap.set_exe_name(ap.get_exe_name().substr(0, ap.get_exe_name().size() - 5));
ap.set_exe_name_to(exe_name);
}
void run() {}
std::vector<std::string> wrong_commands{};
std::string exe_name = "<exe>";
void run()
{
std::cout << color::fg_red << color::bold << "error: " << color::reset;
auto it = std::find_if(wrong_commands.begin(), wrong_commands.end(), [](const auto& c) {
return get_commands().count(c) > 0;
});
if(it == wrong_commands.end())
{
std::cout << "'" << color::fg_yellow << wrong_commands.front() << color::reset
<< "' is not a valid command." << std::endl;
std::cout << get_command_help("Available commands:") << std::endl;
}
else
{
std::cout << "command '" << color::fg_yellow << *it << color::reset
<< "' must be first argument" << std::endl;
std::cout << std::endl;
std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
std::cout << " " << exe_name << " " << *it << " <options>" << std::endl;
}
std::cout << std::endl;
}
};
} // namespace MIGRAPHX_INLINE_NS
......@@ -619,11 +659,11 @@ int main(int argc, const char* argv[])
auto cmd = args.front();
if(m.count(cmd) > 0)
{
m.at(cmd)({args.begin() + 1, args.end()});
m.at(cmd)(argv[0], {args.begin() + 1, args.end()});
}
else
{
run_command<main_command>(args);
run_command<main_command>(argv[0], args);
}
return 0;
......
......@@ -74,6 +74,22 @@ void group_unique(Iterator start, Iterator last, Output out, Predicate pred)
}
}
template <class Iterator1, class Iterator2>
std::ptrdiff_t
levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterator2 last2)
{
if(first1 == last1)
return std::distance(first2, last2);
if(first2 == last2)
return std::distance(first1, last1);
if(*first1 == *first2)
return levenshtein_distance(std::next(first1), last1, std::next(first2), last2);
auto x1 = levenshtein_distance(std::next(first1), last1, std::next(first2), last2);
auto x2 = levenshtein_distance(first1, last1, std::next(first2), last2);
auto x3 = levenshtein_distance(std::next(first1), last1, first2, last2);
return std::ptrdiff_t{1} + std::min({x1, x2, x3});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -38,20 +38,34 @@ struct check_shapes
const shape* begin;
const shape* end;
const std::string name;
const bool dynamic_allowed;
check_shapes(const shape* b, const shape* e, const std::string& n) : begin(b), end(e), name(n)
check_shapes(const shape* b, const shape* e, const std::string& n, const bool d = false)
: begin(b), end(e), name(n), dynamic_allowed(d)
{
check_dynamic();
}
template <class Op>
check_shapes(const shape* b, const shape* e, const Op& op) : begin(b), end(e), name(op.name())
check_shapes(const shape* b, const shape* e, const Op& op, const bool d = false)
: begin(b), end(e), name(op.name()), dynamic_allowed(d)
{
check_dynamic();
}
template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op)
: begin(s.data()), end(s.data() + s.size()), name(op.name())
check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false)
: begin(s.data()), end(s.data() + s.size()), name(op.name()), dynamic_allowed(d)
{
check_dynamic();
}
void check_dynamic() const
{
if(not dynamic_allowed and this->any_of([&](const shape& s) { return s.dynamic(); }))
{
MIGRAPHX_THROW(prefix() + "Dynamic shapes not supported");
}
}
std::string prefix() const
......@@ -92,44 +106,62 @@ struct check_shapes
return *this;
}
/*!
* Check that the first shape has exactly n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const check_shapes& only_dims(std::size_t n) const
{
assert(begin != nullptr);
assert(end != nullptr);
if(begin != end)
{
if(begin->lens().size() != n)
if(begin->max_lens().size() != n)
MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
}
return *this;
}
/*!
* Check that the first shape has a maximum of n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const check_shapes& max_ndims(std::size_t n) const
{
assert(begin != nullptr);
assert(end != nullptr);
if(begin != end)
{
if(begin->lens().size() > n)
if(begin->max_lens().size() > n)
MIGRAPHX_THROW(prefix() + "Shape must have at most " + std::to_string(n) +
" dimensions");
}
return *this;
}
/*!
* Check that the first shape has a minimum of n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const check_shapes& min_ndims(std::size_t n) const
{
assert(begin != nullptr);
assert(end != nullptr);
if(begin != end)
{
if(begin->lens().size() < n)
if(begin->max_lens().size() < n)
MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) +
" dimensions");
}
return *this;
}
/*!
* Check all shapes have the same shape.
*/
const check_shapes& same_shape() const
{
if(!this->same([](const shape& s) { return s; }))
......@@ -137,6 +169,9 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes have the same type.
*/
const check_shapes& same_type() const
{
if(!this->same([](const shape& s) { return s.type(); }))
......@@ -144,20 +179,32 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes have the same lens.
*/
const check_shapes& same_dims() const
{
if(!this->same([](const shape& s) { return s.lens(); }))
if(!this->same([](const shape& s) { return s.max_lens(); }))
MIGRAPHX_THROW(prefix() + "Dimensions do not match");
if(this->any_of([&](const shape& s) { return s.dynamic(); }))
if(!this->same([](const shape& s) { return s.min_lens(); }))
MIGRAPHX_THROW(prefix() + "Min dynamic dimensions do not match");
return *this;
}
/*!
* Check all shapes have the same number of dimensions.
*/
const check_shapes& same_ndims() const
{
if(!this->same([](const shape& s) { return s.lens().size(); }))
if(!this->same([](const shape& s) { return s.max_lens().size(); }))
MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
return *this;
}
/*!
* Check all shapes are standard.
*/
const check_shapes& standard() const
{
if(!this->all_of([](const shape& s) { return s.standard(); }))
......@@ -165,6 +212,9 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes are standard or scalar.
*/
const check_shapes& standard_or_scalar() const
{
if(!this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
......@@ -172,6 +222,9 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes are packed.
*/
const check_shapes& packed() const
{
if(!this->all_of([](const shape& s) { return s.packed(); }))
......@@ -179,6 +232,9 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes are packed or broadcasted.
*/
const check_shapes& packed_or_broadcasted() const
{
if(!this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); }))
......@@ -186,6 +242,9 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes are tuples.
*/
const check_shapes& tuple_type() const
{
if(!this->all_of([](const shape& s) { return s.type() == shape::tuple_type; }))
......@@ -193,6 +252,9 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes are not transposed.
*/
const check_shapes& not_transposed() const
{
if(!this->all_of([](const shape& s) { return not s.transposed(); }))
......@@ -200,6 +262,9 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes are not broadcasted.
*/
const check_shapes& not_broadcasted() const
{
if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
......@@ -207,6 +272,10 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes have the same n elements.
* \param n number of elements
*/
const check_shapes& elements(std::size_t n) const
{
if(!this->all_of([&](const shape& s) { return s.elements() == n; }))
......@@ -214,6 +283,9 @@ struct check_shapes
return *this;
}
/*!
* Check the batches of all the shapes do not have transposed strides.
*/
const check_shapes& batch_not_transposed() const
{
if(!this->all_of([&](const shape& s) { return batch_not_transposed_strides(s.strides()); }))
......@@ -242,6 +314,16 @@ struct check_shapes
return std::all_of(begin, end, p);
}
template <class Predicate>
bool any_of(Predicate p) const
{
if(begin == end)
return false;
assert(begin != nullptr);
assert(end != nullptr);
return std::any_of(begin, end, p);
}
const shape* get(long i) const
{
if(i >= size())
......
......@@ -45,6 +45,11 @@ struct literal : raw_data<literal>
{
literal() {}
/*!
* Empty literal with a specific shape type
*/
explicit literal(shape::type_t shape_type) : m_shape(shape_type, {}) {}
template <class U, class T = deduce<U>, shape::type_t ShapeType = shape::get_type<T>{}>
literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(ShapeType)
{
......
......@@ -33,10 +33,20 @@ inline namespace MIGRAPHX_INLINE_NS {
/// struct to pass in onnx options to parser
struct onnx_options
{
/// default batch size to use (if not specified in onnx file)
std::size_t default_dim_value = 1;
/// Old way to set default fixed dimension size
std::size_t default_dim_value = 0;
/*!
* Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value
* set parser throws)
*/
shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
/// Explicitly specify the dims of an input
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
/*!
* Explicitly specify dynamic dims of an input (if both map_input_dims and
* map_dyn_input_dims set parser throws)
*/
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims = {};
/// Continue parsing onnx file if an unknown operator is found
bool skip_unknown_operators = false;
/// Print program if an error occurs
......
......@@ -24,17 +24,10 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_ABS_HPP
#define MIGRAPHX_GUARD_OPERATORS_ABS_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/unary.hpp>
#include <migraphx/make_signed.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -24,16 +24,9 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_ACOS_HPP
#define MIGRAPHX_GUARD_OPERATORS_ACOS_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/unary.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -24,16 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_ADD_HPP
#define MIGRAPHX_GUARD_OPERATORS_ADD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/binary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -24,16 +24,10 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_ALLOCATE_HPP
#define MIGRAPHX_GUARD_OPERATORS_ALLOCATE_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -26,12 +26,10 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -26,12 +26,10 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -24,16 +24,10 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_AS_SHAPE_HPP
#define MIGRAPHX_GUARD_OPERATORS_AS_SHAPE_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
#include <migraphx/argument.hpp>
#include <migraphx/shape.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -24,16 +24,9 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_ASIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ASIN_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/unary.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -24,16 +24,9 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_ATAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ATAN_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/unary.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -24,16 +24,9 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_BATCH_NORM_HPP
#define MIGRAPHX_GUARD_OPERATORS_BATCH_NORM_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -26,9 +26,7 @@
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
......
......@@ -24,14 +24,9 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_BROADCAST_HPP
#define MIGRAPHX_GUARD_OPERATORS_BROADCAST_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -24,11 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP
#define MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/context.hpp>
......
......@@ -25,16 +25,12 @@
#define MIGRAPHX_GUARD_OPERATORS_CLIP_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <cmath>
#include <utility>
#include <limits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment