Commit 7e297b13 authored by Paul's avatar Paul
Browse files

Merge

parents 86ea5e91 aa7ff911
...@@ -15,7 +15,7 @@ struct module; ...@@ -15,7 +15,7 @@ struct module;
struct simplify_algebra struct simplify_algebra
{ {
std::string name() const { return "simplify_algebra"; } std::string name() const { return "simplify_algebra"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DECOMPOSE_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_QDQ_HPP
#define MIGRAPHX_GUARD_RTGLIB_DECOMPOSE_HPP #define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_QDQ_HPP
#include <string> #include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -11,12 +10,13 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,12 +10,13 @@ inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module;
/** /**
* Decompose operators. * Inserts quantized operators in place of dq->quantizable_op->q
* then removes remaining fake quantization (q->dq pairs)
*/ */
struct decompose struct simplify_qdq
{ {
std::string name() const { return "decompose"; } std::string name() const { return "simplify_qdq"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -16,7 +16,7 @@ struct module; ...@@ -16,7 +16,7 @@ struct module;
struct simplify_reshapes struct simplify_reshapes
{ {
std::string name() const { return "simplify_reshapes"; } std::string name() const { return "simplify_reshapes"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -36,20 +36,26 @@ struct stream_model ...@@ -36,20 +36,26 @@ struct stream_model
#else #else
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct stream_model struct stream_model
* { {
* std::size_t get_nstream() const; //
* std::size_t get_stream(instruction_ref ins) const; std::size_t get_nstream() const;
* std::size_t get_event_id(instruction_ref ins) const; //
* bool has_stream(instruction_ref ins) const; std::size_t get_stream(instruction_ref ins) const;
* bool is_record(instruction_ref ins) const; //
* bool is_wait(instruction_ref ins) const; std::size_t get_event_id(instruction_ref ins) const;
* }; //
* bool has_stream(instruction_ref ins) const;
*/ //
bool is_record(instruction_ref ins) const;
//
bool is_wait(instruction_ref ins) const;
};
#else
struct stream_model struct stream_model
{ {
...@@ -296,6 +302,7 @@ inline const ValueType& any_cast(const stream_model& x) ...@@ -296,6 +302,7 @@ inline const ValueType& any_cast(const stream_model& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -18,7 +18,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -18,7 +18,7 @@ inline namespace MIGRAPHX_INLINE_NS {
template <class F> template <class F>
auto with_char(F f) auto with_char(F f)
{ {
return [=](unsigned char c) { return f(c); }; return [=](unsigned char c) -> bool { return f(c); };
} }
inline std::string inline std::string
...@@ -71,7 +71,7 @@ std::string trim(const std::string& s, F f) ...@@ -71,7 +71,7 @@ std::string trim(const std::string& s, F f)
{ {
auto start = std::find_if_not(s.begin(), s.end(), f); auto start = std::find_if_not(s.begin(), s.end(), f);
auto last = std::find_if_not(s.rbegin(), std::string::const_reverse_iterator(start), f).base(); auto last = std::find_if_not(s.rbegin(), std::string::const_reverse_iterator(start), f).base();
return std::string(start, last); return {start, last};
} }
inline std::string trim(const std::string& s) inline std::string trim(const std::string& s)
...@@ -120,22 +120,28 @@ interpolate_string(const std::string& input, F f, std::string start = "${", std: ...@@ -120,22 +120,28 @@ interpolate_string(const std::string& input, F f, std::string start = "${", std:
result.append(it, next_start); result.append(it, next_start);
if(next_start == input.end()) if(next_start == input.end())
break; break;
auto r = f(next_start + start.size(), next_end - end.size() + 1); auto r = f(next_start + start.size(), next_end);
result.append(r.begin(), r.end()); result.append(r.begin(), r.end());
it = next_end + 1; it = next_end + end.size();
} }
return result; return result;
} }
inline std::string interpolate_string(const std::string& input, inline std::string interpolate_string(const std::string& input,
const std::unordered_map<std::string, std::string>& vars) const std::unordered_map<std::string, std::string>& vars,
{ std::string start = "${",
return interpolate_string(input, [&](auto start, auto last) { std::string end = "}")
auto key = trim({start, last}); {
auto it = vars.find(key); return interpolate_string(
if(it == vars.end()) input,
throw std::runtime_error("Unknown key: " + key); [&](auto start_it, auto last_it) {
return it->second; auto key = trim({start_it, last_it});
}); auto it = vars.find(key);
if(it == vars.end())
throw std::runtime_error("Unknown key: " + key);
return it->second;
},
std::move(start),
std::move(end));
} }
template <class Iterator> template <class Iterator>
...@@ -163,7 +169,8 @@ inline std::string to_string_range(const std::initializer_list<T>& r) ...@@ -163,7 +169,8 @@ inline std::string to_string_range(const std::initializer_list<T>& r)
} }
template <class T> template <class T>
inline std::string to_string(const T& x) inline auto to_string(const T& x)
-> decltype((std::declval<std::stringstream>() << x), std::string{})
{ {
std::stringstream ss; std::stringstream ss;
ss << x; ss << x;
......
...@@ -82,20 +82,26 @@ argument copy_from_target(T&, const argument& arg) ...@@ -82,20 +82,26 @@ argument copy_from_target(T&, const argument& arg)
return arg; return arg;
} }
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct target struct target
* { {
* std::string name() const; //
* std::vector<pass> get_passes(context& ctx,const compile_options& options) const; std::string name() const;
* context get_context() const; //
* argument copy_to(const argument& input) const; std::vector<pass> get_passes(context& ctx, const compile_options& options) const;
* argument copy_from(const argument& input) const; //
* argument allocate(const shape& s) const; context get_context() const;
* }; // (optional)
* argument copy_to(const argument& input) const;
*/ // (optional)
argument copy_from(const argument& input) const;
// (optional)
argument allocate(const shape& s) const;
};
#else
struct target struct target
{ {
...@@ -382,6 +388,7 @@ inline const ValueType& any_cast(const target& x) ...@@ -382,6 +388,7 @@ inline const ValueType& any_cast(const target& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -120,10 +120,8 @@ struct tensor_view ...@@ -120,10 +120,8 @@ struct tensor_view
return m_data[m_shape.index(this->size() - 1)]; return m_data[m_shape.index(this->size() - 1)];
} }
// cppcheck-suppress functionConst
iterator begin() { return {0, {this}}; } iterator begin() { return {0, {this}}; }
// cppcheck-suppress functionConst
iterator end() { return {this->size(), {this}}; } iterator end() { return {this->size(), {this}}; }
const_iterator begin() const { return {0, {this}}; } const_iterator begin() const { return {0, {this}}; }
......
...@@ -178,6 +178,7 @@ struct value ...@@ -178,6 +178,7 @@ struct value
value(std::nullptr_t); value(std::nullptr_t);
value(const char* i); value(const char* i);
value(const std::string& pkey, const char* i);
#define MIGRAPHX_VALUE_GENERATE_DECL_METHODS(vt, cpp_type) \ #define MIGRAPHX_VALUE_GENERATE_DECL_METHODS(vt, cpp_type) \
value(cpp_type i); \ value(cpp_type i); \
...@@ -188,6 +189,12 @@ struct value ...@@ -188,6 +189,12 @@ struct value
const cpp_type* if_##vt() const; const cpp_type* if_##vt() const;
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DECL_METHODS) MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DECL_METHODS)
template <class T>
using literal_to_string = std::conditional_t<(std::is_convertible<T, const char*>{} and
std::is_convertible<T, std::string>{}),
std::string,
T>;
template <class T> template <class T>
using pick_numeric = std::conditional_t< using pick_numeric = std::conditional_t<
std::is_floating_point<T>{}, std::is_floating_point<T>{},
...@@ -246,6 +253,7 @@ struct value ...@@ -246,6 +253,7 @@ struct value
return *this = from_values(rhs); // NOLINT return *this = from_values(rhs); // NOLINT
} }
value& operator=(const char* c);
value& operator=(std::nullptr_t); value& operator=(std::nullptr_t);
value& operator=(const std::initializer_list<value>& i); value& operator=(const std::initializer_list<value>& i);
...@@ -315,8 +323,7 @@ struct value ...@@ -315,8 +323,7 @@ struct value
{ {
switch(this->get_type()) switch(this->get_type())
{ {
case null_type: case null_type: {
{
std::nullptr_t null{}; std::nullptr_t null{};
if(this->key.empty()) if(this->key.empty())
v(null); v(null);
...@@ -325,8 +332,7 @@ struct value ...@@ -325,8 +332,7 @@ struct value
return; return;
} }
#define MIGRAPHX_VALUE_GENERATE_CASE(vt, cpp_type) \ #define MIGRAPHX_VALUE_GENERATE_CASE(vt, cpp_type) \
case vt##_type: \ case vt##_type: { \
{ \
if(this->key.empty()) \ if(this->key.empty()) \
v(this->get_##vt()); \ v(this->get_##vt()); \
else \ else \
...@@ -346,19 +352,17 @@ struct value ...@@ -346,19 +352,17 @@ struct value
{ {
switch(this->get_type()) switch(this->get_type())
{ {
case null_type: case null_type: {
{
std::nullptr_t null{}; std::nullptr_t null{};
v(null); v(null);
return; return;
} }
#define MIGRAPHX_VALUE_GENERATE_CASE_VALUE(vt, cpp_type) \ #define MIGRAPHX_VALUE_GENERATE_CASE_VALUE(vt, cpp_type) \
case vt##_type: \ case vt##_type: { \
{ \
v(this->get_##vt()); \ v(this->get_##vt()); \
return; \ return; \
} }
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE) MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE)
MIGRAPHX_VALUE_GENERATE_CASE(array, ) MIGRAPHX_VALUE_GENERATE_CASE(array, )
MIGRAPHX_VALUE_GENERATE_CASE(object, ) MIGRAPHX_VALUE_GENERATE_CASE(object, )
} }
...@@ -374,11 +378,11 @@ struct value ...@@ -374,11 +378,11 @@ struct value
} }
template <class To> template <class To>
To value_or(const To& default_value) const literal_to_string<To> value_or(const To& default_value) const
{ {
if(this->is_null()) if(this->is_null())
return default_value; return default_value;
return to<To>(); return to<literal_to_string<To>>();
} }
template <class To> template <class To>
...@@ -394,12 +398,12 @@ struct value ...@@ -394,12 +398,12 @@ struct value
} }
template <class To> template <class To>
To get(const std::string& pkey, const To& default_value) const literal_to_string<To> get(const std::string& pkey, const To& default_value) const
{ {
const auto* v = find(pkey); const auto* v = find(pkey);
if(v == this->end()) if(v == this->end())
return default_value; return default_value;
return v->to<To>(); return v->to<literal_to_string<To>>();
} }
template <class To> template <class To>
...@@ -412,10 +416,11 @@ struct value ...@@ -412,10 +416,11 @@ struct value
} }
template <class To> template <class To>
std::vector<To> get(const std::string& pkey, std::vector<literal_to_string<To>> get(const std::string& pkey,
const std::initializer_list<To>& default_value) const const std::initializer_list<To>& default_value) const
{ {
return get<std::vector<To>>(pkey, default_value); return get(pkey,
std::vector<literal_to_string<To>>{default_value.begin(), default_value.end()});
} }
friend bool operator==(const value& x, const value& y); friend bool operator==(const value& x, const value& y);
...@@ -429,6 +434,8 @@ struct value ...@@ -429,6 +434,8 @@ struct value
void debug_print(bool show_type = false) const; void debug_print(bool show_type = false) const;
type_t get_type() const;
private: private:
template <class T> template <class T>
std::vector<value> from_values(const T& r) std::vector<value> from_values(const T& r)
...@@ -438,7 +445,6 @@ struct value ...@@ -438,7 +445,6 @@ struct value
r.begin(), r.end(), std::back_inserter(v), [&](auto&& e) { return value(e); }); r.begin(), r.end(), std::back_inserter(v), [&](auto&& e) { return value(e); });
return v; return v;
} }
type_t get_type() const;
std::shared_ptr<value_base_impl> x; std::shared_ptr<value_base_impl> x;
std::string key; std::string key;
}; };
......
...@@ -168,7 +168,6 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out ...@@ -168,7 +168,6 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out
{ {
double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance; double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance;
auto error = rms_range(r1, r2); auto error = rms_range(r1, r2);
// cppcheck-suppress uninitvar
if(out_error != nullptr) if(out_error != nullptr)
*out_error = error; *out_error = error;
return error <= threshold; return error <= threshold;
......
...@@ -11,49 +11,8 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,49 +11,8 @@ inline namespace MIGRAPHX_INLINE_NS {
static void inline_submodule(module& m, instruction_ref ins, bool cond) static void inline_submodule(module& m, instruction_ref ins, bool cond)
{ {
const auto& mod_inputs = ins->module_inputs(); const auto& mod_inputs = ins->module_inputs();
const auto* smod = cond ? mod_inputs.at(0) : mod_inputs.at(1); module_ref smod = cond ? mod_inputs.at(0) : mod_inputs.at(1);
auto mod_outputs = m.insert_module_instructions(ins, smod);
std::unordered_map<instruction_ref, instruction_ref> map_ins;
std::vector<instruction_ref> mod_outputs;
for(auto sins : iterator_for(*smod))
{
instruction_ref copy_ins{};
if(sins->name() == "@literal")
{
auto l = sins->get_literal();
copy_ins = m.add_literal(l);
}
else if(sins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
auto s = sins->get_shape();
copy_ins = m.add_parameter(name, s);
}
else if(sins->name() == "@outline")
{
auto s = sins->get_shape();
copy_ins = m.add_outline(s);
}
else
{
auto mod_args = sins->module_inputs();
auto inputs = sins->inputs();
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return contains(map_ins, i) ? map_ins[i] : i;
});
if(sins->name() == "@return")
{
mod_outputs = copy_inputs;
break;
}
copy_ins = m.insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
}
map_ins[sins] = copy_ins;
mod_outputs = {copy_ins};
}
auto ins_outputs = ins->outputs(); auto ins_outputs = ins->outputs();
assert(mod_outputs.size() >= ins_outputs.size()); assert(mod_outputs.size() >= ins_outputs.size());
......
...@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins, ...@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m) static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{ {
auto op = any_cast<op::pooling>(ins->get_operator()); auto op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "average") if(op.mode == op::pooling_mode::average)
{ {
return; return;
} }
......
...@@ -468,5 +468,11 @@ std::vector<shape> try_compute_shape(const operation& op, const std::vector<shap ...@@ -468,5 +468,11 @@ std::vector<shape> try_compute_shape(const operation& op, const std::vector<shap
} }
return {new_shape}; return {new_shape};
} }
migraphx::instruction* as_address(const instruction_ref& ins) noexcept
{
return std::addressof(*ins);
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -133,6 +133,12 @@ std::string to_json_string(const value& val) ...@@ -133,6 +133,12 @@ std::string to_json_string(const value& val)
return j.dump(); return j.dump();
} }
std::string to_pretty_json_string(const value& val, std::size_t indent)
{
json j = val;
return j.dump(indent);
}
migraphx::value from_json_string(const char* str, std::size_t size) migraphx::value from_json_string(const char* str, std::size_t size)
{ {
json j = json::parse(str, str + size); json j = json::parse(str, str + size);
......
...@@ -5,20 +5,41 @@ namespace migraphx { ...@@ -5,20 +5,41 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name) { return load_op(name); } operation make_op(const std::string& name) { return load_op(name); }
operation make_op(const std::string& name, const value& v)
template <class F>
operation make_op_generic(const std::string& name, F for_each)
{ {
if(not(v.is_object() or (v.empty() and v.is_array())))
MIGRAPHX_THROW("Value is not an object");
auto op = load_op(name); auto op = load_op(name);
// Merge values // Merge values
value w = op.to_value(); value w = op.to_value();
for(auto&& x : v) for_each([&](const auto& key, const auto& x) {
{ if(not w.contains(key))
w.at(x.get_key()) = x.without_key(); // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
} MIGRAPHX_THROW("No key '" + key + "' in " + name);
w.at(key) = x;
});
op.from_value(w); op.from_value(w);
return op; return op;
} }
operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v)
{
return make_op_generic(name, [&](auto f) {
for(auto&& [key, x] : v)
f(key, x);
});
}
operation make_op_from_value(const std::string& name, const value& v)
{
if(not(v.is_object() or (v.empty() and v.is_array())))
MIGRAPHX_THROW("Value is not an object for make_op: " + name);
return make_op_generic(name, [&](auto f) {
for(auto&& x : v)
f(x.get_key(), x.without_key());
});
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#include <iterator>
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -21,6 +22,8 @@ ...@@ -21,6 +22,8 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_FINALIZE)
struct module_impl struct module_impl
{ {
// A list is used to keep references to an instruction stable // A list is used to keep references to an instruction stable
...@@ -28,6 +31,7 @@ struct module_impl ...@@ -28,6 +31,7 @@ struct module_impl
std::unordered_set<instruction*> instruction_set; std::unordered_set<instruction*> instruction_set;
std::string name; std::string name;
uint32_t nparams = 0; uint32_t nparams = 0;
bool bypass = false;
bool contains(instruction_ref ins) const bool contains(instruction_ref ins) const
{ {
...@@ -49,6 +53,13 @@ struct module_impl ...@@ -49,6 +53,13 @@ struct module_impl
return emplace(pos, ins); return emplace(pos, ins);
} }
void clear()
{
instructions.clear();
instruction_set.clear();
nparams = 0;
}
void push_front(const instruction& ins) { insert(instructions.begin(), ins); } void push_front(const instruction& ins) { insert(instructions.begin(), ins); }
void push_back(const instruction& ins) { insert(instructions.end(), ins); } void push_back(const instruction& ins) { insert(instructions.end(), ins); }
...@@ -100,18 +111,21 @@ module& module::operator=(module m) ...@@ -100,18 +111,21 @@ module& module::operator=(module m)
std::string module::name() const { return impl->name; } std::string module::name() const { return impl->name; }
bool module::bypass() const { return impl->bypass; }
void module::set_bypass(bool b) { impl->bypass = b; }
void module::assign(const module& m) void module::assign(const module& m)
{ {
// clean the current module // copy the impl
if(!impl) if(!impl)
{
impl = std::make_unique<module_impl>(); impl = std::make_unique<module_impl>();
} *impl = *m.impl;
else if(!impl->instructions.empty())
// clear instructions
if(!impl->instructions.empty())
{ {
impl->instructions.clear(); impl->clear();
} }
impl->name = m.impl->name;
std::unordered_map<instruction_ref, instruction_ref> ins_map; std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
...@@ -125,9 +139,10 @@ void module::assign(const module& m) ...@@ -125,9 +139,10 @@ void module::assign(const module& m)
else if(ins->name() == "@param") else if(ins->name() == "@param")
{ {
auto&& name = any_cast<builtin::param>(ins->get_operator()).parameter; auto&& name = any_cast<builtin::param>(ins->get_operator()).parameter;
auto order = any_cast<builtin::param>(ins->get_operator()).order;
auto s = ins->get_shape(); auto s = ins->get_shape();
copy_ins = copy_ins = impl->insert(impl->instructions.end(),
impl->insert(impl->instructions.end(), {builtin::param{name}, std::move(s), {}}); {builtin::param{name, order}, std::move(s), {}});
} }
else if(ins->name() == "@outline") else if(ins->name() == "@outline")
{ {
...@@ -166,6 +181,7 @@ instruction_ref module::insert_instruction(instruction_ref ins, ...@@ -166,6 +181,7 @@ instruction_ref module::insert_instruction(instruction_ref ins,
const operation& op, const operation& op,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
{ {
assert(has_instruction(ins) or is_end(ins, this->end()));
assert(not starts_with(op.name(), "@")); assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args); shape r = compute_shape(op, args);
auto result = impl->insert(ins, {op, r, std::move(args)}); auto result = impl->insert(ins, {op, r, std::move(args)});
...@@ -187,6 +203,7 @@ instruction_ref module::insert_instruction(instruction_ref ins, ...@@ -187,6 +203,7 @@ instruction_ref module::insert_instruction(instruction_ref ins,
std::vector<instruction_ref> args, std::vector<instruction_ref> args,
std::vector<module_ref> module_args) std::vector<module_ref> module_args)
{ {
assert(has_instruction(ins) or is_end(ins, this->end()));
assert(not starts_with(op.name(), "@")); assert(not starts_with(op.name(), "@"));
auto out_shape = compute_shape(op, args, module_args); auto out_shape = compute_shape(op, args, module_args);
auto result = impl->insert(ins, {op, out_shape, std::move(args), std::move(module_args)}); auto result = impl->insert(ins, {op, out_shape, std::move(args), std::move(module_args)});
...@@ -199,6 +216,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, ...@@ -199,6 +216,7 @@ instruction_ref module::replace_instruction(instruction_ref ins,
const operation& op, const operation& op,
std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST
{ {
assert(has_instruction(ins));
assert(not starts_with(op.name(), "@")); assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args); shape r = compute_shape(op, args);
...@@ -212,6 +230,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, ...@@ -212,6 +230,7 @@ instruction_ref module::replace_instruction(instruction_ref ins,
std::vector<instruction_ref> args, std::vector<instruction_ref> args,
std::vector<module_ref> module_args) MIGRAPHX_TIDY_CONST std::vector<module_ref> module_args) MIGRAPHX_TIDY_CONST
{ {
assert(has_instruction(ins));
assert(not starts_with(op.name(), "@")); assert(not starts_with(op.name(), "@"));
auto out_shape = compute_shape(op, args, module_args); auto out_shape = compute_shape(op, args, module_args);
instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args)); instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args));
...@@ -278,6 +297,8 @@ instruction_ref module::remove_instructions(instruction_ref first, instruction_r ...@@ -278,6 +297,8 @@ instruction_ref module::remove_instructions(instruction_ref first, instruction_r
instruction_ref module::move_instruction(instruction_ref src, instruction_ref dst) instruction_ref module::move_instruction(instruction_ref src, instruction_ref dst)
{ {
assert(has_instruction(src));
assert(has_instruction(dst) or is_end(dst, this->end()));
impl->instructions.splice(dst, impl->instructions, src); impl->instructions.splice(dst, impl->instructions, src);
return src; return src;
} }
...@@ -290,6 +311,55 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d ...@@ -290,6 +311,55 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d
return src; return src;
} }
std::vector<instruction_ref> module::insert_module_instructions(
instruction_ref ins, module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
std::vector<instruction_ref> mod_outputs;
for(auto sins : iterator_for(*m))
{
if(contains(map_ins, sins))
continue;
instruction_ref copy_ins;
if(sins->name() == "@literal")
{
auto l = sins->get_literal();
copy_ins = this->add_literal(l);
}
else if(sins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
auto s = sins->get_shape();
copy_ins = this->add_parameter(name, s);
}
else if(sins->name() == "@outline")
{
auto s = sins->get_shape();
copy_ins = this->add_outline(s);
}
else
{
auto mod_args = sins->module_inputs();
auto inputs = sins->inputs();
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return contains(map_ins, i) ? map_ins[i] : i;
});
if(sins->name() == "@return")
{
mod_outputs = copy_inputs;
break;
}
copy_ins = this->insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
}
map_ins[sins] = copy_ins;
}
if(mod_outputs.empty())
mod_outputs = {map_ins.at(std::prev(m->end()))};
return mod_outputs;
}
instruction_ref module::add_literal(literal l) instruction_ref module::add_literal(literal l)
{ {
impl->emplace_front(std::move(l)); impl->emplace_front(std::move(l));
...@@ -320,6 +390,20 @@ instruction_ref module::add_return(std::vector<instruction_ref> args) ...@@ -320,6 +390,20 @@ instruction_ref module::add_return(std::vector<instruction_ref> args)
return result; return result;
} }
instruction_ref module::replace_return(std::vector<instruction_ref> args)
{
auto last = std::prev(this->end());
// If there is no return then add a return
if(last->name() != "@return")
return this->add_return(args);
shape r = compute_shape(last->get_operator(), args);
instruction::replace(last, last->get_operator(), r, std::move(args));
assert(last->valid(begin()));
return last;
}
shape module::get_parameter_shape(std::string name) const shape module::get_parameter_shape(std::string name) const
{ {
auto ins = std::find_if( auto ins = std::find_if(
...@@ -334,7 +418,6 @@ shape module::get_parameter_shape(std::string name) const ...@@ -334,7 +418,6 @@ shape module::get_parameter_shape(std::string name) const
} }
}); });
if(ins != this->end()) if(ins != this->end())
return ins->get_shape(); return ins->get_shape();
else else
return {}; return {};
...@@ -430,7 +513,6 @@ instruction_ref module::validate() const ...@@ -430,7 +513,6 @@ instruction_ref module::validate() const
bool check_order = std::all_of(inputs.begin(), inputs.end(), [&](auto in) { bool check_order = std::all_of(inputs.begin(), inputs.end(), [&](auto in) {
return contains(impl->instructions, *in); return contains(impl->instructions, *in);
}); });
return !i.valid(impl->instructions.begin(), check_order); return !i.valid(impl->instructions.begin(), check_order);
}); });
} }
...@@ -473,8 +555,14 @@ instruction_ref module::find_dangling_reference() const ...@@ -473,8 +555,14 @@ instruction_ref module::find_dangling_reference() const
void module::finalize(context& ctx) void module::finalize(context& ctx)
{ {
const bool trace = enabled(MIGRAPHX_TRACE_FINALIZE{});
for(auto ins : iterator_for(*this)) for(auto ins : iterator_for(*this))
{ {
if(trace)
{
std::cout << "Finalize: ";
this->debug_print(ins);
}
ins->finalize(ctx); ins->finalize(ctx);
for(const auto& smod : ins->module_inputs()) for(const auto& smod : ins->module_inputs())
{ {
...@@ -547,8 +635,9 @@ std::unordered_map<instruction_ref, std::string> module::print( ...@@ -547,8 +635,9 @@ std::unordered_map<instruction_ref, std::string> module::print(
var_name = this->name(); var_name = this->name();
var_name.append((this->name().empty() ? "@" : ":@")); var_name.append((this->name().empty() ? "@" : ":@"));
var_name.append(std::to_string(count)); var_name.append(std::to_string(count));
count++;
} }
// count every instruction so index matches loc in the printout program
count++;
names.emplace(ins, var_name); names.emplace(ins, var_name);
print_func(ins, names); print_func(ins, names);
...@@ -648,7 +737,6 @@ std::unordered_map<instruction_ref, std::string> ...@@ -648,7 +737,6 @@ std::unordered_map<instruction_ref, std::string>
module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const
{ {
os << "migraphx::module p;" << std::endl; os << "migraphx::module p;" << std::endl;
// cppcheck-suppress variableScope
unsigned long seed = 0; unsigned long seed = 0;
names = this->print( names = this->print(
[&](auto ins, auto ins_names) { [&](auto ins, auto ins_names) {
......
...@@ -14,44 +14,36 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) ...@@ -14,44 +14,36 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
{ {
switch(o.type) switch(o.type)
{ {
case msgpack::type::NIL: case msgpack::type::NIL: {
{
v = nullptr; v = nullptr;
break; break;
} }
case msgpack::type::BOOLEAN: case msgpack::type::BOOLEAN: {
{
v = o.as<bool>(); v = o.as<bool>();
break; break;
} }
case msgpack::type::POSITIVE_INTEGER: case msgpack::type::POSITIVE_INTEGER: {
{
v = o.as<std::uint64_t>(); v = o.as<std::uint64_t>();
break; break;
} }
case msgpack::type::NEGATIVE_INTEGER: case msgpack::type::NEGATIVE_INTEGER: {
{
v = o.as<std::int64_t>(); v = o.as<std::int64_t>();
break; break;
} }
case msgpack::type::FLOAT32: case msgpack::type::FLOAT32:
case msgpack::type::FLOAT64: case msgpack::type::FLOAT64: {
{
v = o.as<double>(); v = o.as<double>();
break; break;
} }
case msgpack::type::STR: case msgpack::type::STR: {
{
v = o.as<std::string>(); v = o.as<std::string>();
break; break;
} }
case msgpack::type::BIN: case msgpack::type::BIN: {
{
v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size}; v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size};
break; break;
} }
case msgpack::type::ARRAY: case msgpack::type::ARRAY: {
{
migraphx::value r = migraphx::value::array{}; migraphx::value r = migraphx::value::array{};
std::for_each( std::for_each(
o.via.array.ptr, o.via.array.ptr,
...@@ -60,8 +52,7 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) ...@@ -60,8 +52,7 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
v = r; v = r;
break; break;
} }
case msgpack::type::MAP: case msgpack::type::MAP: {
{
migraphx::value r = migraphx::value::object{}; migraphx::value r = migraphx::value::object{};
std::for_each(o.via.map.ptr, std::for_each(o.via.map.ptr,
o.via.map.ptr + o.via.map.size, o.via.map.ptr + o.via.map.size,
...@@ -71,7 +62,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) ...@@ -71,7 +62,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
v = r; v = r;
break; break;
} }
case msgpack::type::EXT: { MIGRAPHX_THROW("msgpack EXT type not supported."); case msgpack::type::EXT: {
MIGRAPHX_THROW("msgpack EXT type not supported.");
} }
} }
return o; return o;
......
...@@ -20,7 +20,7 @@ auto tune_attribute(const std::vector<int64_t>& vec, ...@@ -20,7 +20,7 @@ auto tune_attribute(const std::vector<int64_t>& vec,
const std::vector<std::size_t>& lens) const std::vector<std::size_t>& lens)
{ {
std::vector<int64_t> result(vec); std::vector<int64_t> result(vec);
int64_t n_rank = static_cast<int64_t>(lens.size()); int64_t n_rank = lens.size();
std::vector<op::normalize_attribute> vec_attrs = val.to_vector<op::normalize_attribute>(); std::vector<op::normalize_attribute> vec_attrs = val.to_vector<op::normalize_attribute>();
if(contains(vec_attrs, op::normalize_attribute::use_output)) if(contains(vec_attrs, op::normalize_attribute::use_output))
{ {
......
...@@ -7,7 +7,7 @@ target_compile_options(onnx-proto PRIVATE -w) ...@@ -7,7 +7,7 @@ target_compile_options(onnx-proto PRIVATE -w)
target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY}) target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On) set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
file(GLOB ONNX_SRCS *.cpp) file(GLOB ONNX_SRCS ${CONFIGURE_DEPENDS} *.cpp)
add_library(migraphx_onnx ${ONNX_SRCS}) add_library(migraphx_onnx ${ONNX_SRCS})
target_include_directories(migraphx_onnx PRIVATE include) target_include_directories(migraphx_onnx PRIVATE include)
set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx) set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
......
...@@ -32,9 +32,20 @@ struct onnx_parser ...@@ -32,9 +32,20 @@ struct onnx_parser
instruction_ref add_bias(const std::vector<instruction_ref>& args, instruction_ref add_bias(const std::vector<instruction_ref>& args,
instruction_ref curr_ins, instruction_ref curr_ins,
uint64_t axis) const; uint64_t axis) const;
instruction_ref add_broadcastable_binary_op(const std::string& op_name, instruction_ref add_broadcastable_binary_op(const std::string& op_name,
instruction_ref arg0, instruction_ref arg0,
instruction_ref arg1) const; instruction_ref arg1) const;
instruction_ref add_common_op(const std::string& op_name,
std::vector<instruction_ref> inputs) const;
template <class... Ts>
instruction_ref add_common_op(const std::string& op_name, Ts... xs) const
{
return add_common_op(op_name, {xs...});
}
instruction_ref add_instruction(const operation& op, instruction_ref add_instruction(const operation& op,
const std::vector<instruction_ref>& args) const; const std::vector<instruction_ref>& args) const;
...@@ -63,6 +74,7 @@ struct onnx_parser ...@@ -63,6 +74,7 @@ struct onnx_parser
std::size_t default_dim_value = 1; std::size_t default_dim_value = 1;
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
bool skip_unknown_operators = false; bool skip_unknown_operators = false;
int64_t max_loop_iterations = 10;
int64_t opset_version = 13; int64_t opset_version = 13;
std::unordered_map<std::string, op_func> ops; std::unordered_map<std::string, op_func> ops;
......
#include <migraphx/onnx/onnx_parser.hpp> #include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <iostream> #include <iostream>
#include <fstream> #include <fstream>
#include <unordered_map> #include <unordered_map>
...@@ -20,6 +21,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs) ...@@ -20,6 +21,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
parser.map_input_dims = options.map_input_dims; parser.map_input_dims = options.map_input_dims;
parser.default_dim_value = options.default_dim_value; parser.default_dim_value = options.default_dim_value;
parser.skip_unknown_operators = options.skip_unknown_operators; parser.skip_unknown_operators = options.skip_unknown_operators;
parser.max_loop_iterations = options.max_loop_iterations;
if(options.print_program_on_error) if(options.print_program_on_error)
{ {
...@@ -57,5 +59,7 @@ program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options ...@@ -57,5 +59,7 @@ program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options
return parse_onnx_from(options, data, size); return parse_onnx_from(options, data, size);
} }
std::vector<std::string> get_onnx_operators() { return onnx::get_op_parsers(); }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment