Commit 11e155c2 authored by Paul's avatar Paul
Browse files

Merge

parents 8a9c5bce aa7ff911
...@@ -35,7 +35,7 @@ struct shape ...@@ -35,7 +35,7 @@ struct shape
m(int64_type, int64_t) \ m(int64_type, int64_t) \
m(uint32_type, uint32_t) \ m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) m(uint64_type, uint64_t)
// clang-format on // clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x, #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum type_t enum type_t
...@@ -131,6 +131,8 @@ struct shape ...@@ -131,6 +131,8 @@ struct shape
shape with_lens(type_t t, const std::vector<std::size_t>& l) const; shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
shape with_lens(const std::vector<std::size_t>& l) const; shape with_lens(const std::vector<std::size_t>& l) const;
shape with_type(type_t t) const;
friend bool operator==(const shape& x, const shape& y); friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y); friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x); friend std::ostream& operator<<(std::ostream& os, const shape& x);
...@@ -186,8 +188,7 @@ struct shape ...@@ -186,8 +188,7 @@ struct shape
{ {
switch(t) switch(t)
{ {
case tuple_type: case tuple_type: {
{
tv(); tv();
return; return;
} }
...@@ -224,10 +225,11 @@ struct shape ...@@ -224,10 +225,11 @@ struct shape
const std::vector<shape>& sub_shapes() const; const std::vector<shape>& sub_shapes() const;
std::size_t element_space() const;
private: private:
shape(std::shared_ptr<shape_impl> pimpl);
std::shared_ptr<const shape_impl> impl; std::shared_ptr<const shape_impl> impl;
std::size_t element_space() const;
}; };
void migraphx_to_value(value& v, const shape& s); void migraphx_to_value(value& v, const shape& s);
......
...@@ -15,7 +15,7 @@ struct module; ...@@ -15,7 +15,7 @@ struct module;
struct simplify_algebra struct simplify_algebra
{ {
std::string name() const { return "simplify_algebra"; } std::string name() const { return "simplify_algebra"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -16,7 +16,7 @@ struct module; ...@@ -16,7 +16,7 @@ struct module;
struct simplify_reshapes struct simplify_reshapes
{ {
std::string name() const { return "simplify_reshapes"; } std::string name() const { return "simplify_reshapes"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -36,20 +36,26 @@ struct stream_model ...@@ -36,20 +36,26 @@ struct stream_model
#else #else
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct stream_model struct stream_model
* { {
* std::size_t get_nstream() const; //
* std::size_t get_stream(instruction_ref ins) const; std::size_t get_nstream() const;
* std::size_t get_event_id(instruction_ref ins) const; //
* bool has_stream(instruction_ref ins) const; std::size_t get_stream(instruction_ref ins) const;
* bool is_record(instruction_ref ins) const; //
* bool is_wait(instruction_ref ins) const; std::size_t get_event_id(instruction_ref ins) const;
* }; //
* bool has_stream(instruction_ref ins) const;
*/ //
bool is_record(instruction_ref ins) const;
//
bool is_wait(instruction_ref ins) const;
};
#else
struct stream_model struct stream_model
{ {
...@@ -296,6 +302,7 @@ inline const ValueType& any_cast(const stream_model& x) ...@@ -296,6 +302,7 @@ inline const ValueType& any_cast(const stream_model& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -137,16 +137,17 @@ inline std::string interpolate_string(const std::string& input, ...@@ -137,16 +137,17 @@ inline std::string interpolate_string(const std::string& input,
std::string start = "${", std::string start = "${",
std::string end = "}") std::string end = "}")
{ {
return interpolate_string(input, return interpolate_string(
[&](auto start_it, auto last_it) { input,
auto key = trim({start_it, last_it}); [&](auto start_it, auto last_it) {
auto it = vars.find(key); auto key = trim({start_it, last_it});
if(it == vars.end()) auto it = vars.find(key);
throw std::runtime_error("Unknown key: " + key); if(it == vars.end())
return it->second; throw std::runtime_error("Unknown key: " + key);
}, return it->second;
std::move(start), },
std::move(end)); std::move(start),
std::move(end));
} }
template <class Iterator> template <class Iterator>
......
...@@ -82,20 +82,26 @@ argument copy_from_target(T&, const argument& arg) ...@@ -82,20 +82,26 @@ argument copy_from_target(T&, const argument& arg)
return arg; return arg;
} }
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct target struct target
* { {
* std::string name() const; //
* std::vector<pass> get_passes(context& ctx,const compile_options& options) const; std::string name() const;
* context get_context() const; //
* argument copy_to(const argument& input) const; std::vector<pass> get_passes(context& ctx, const compile_options& options) const;
* argument copy_from(const argument& input) const; //
* argument allocate(const shape& s) const; context get_context() const;
* }; // (optional)
* argument copy_to(const argument& input) const;
*/ // (optional)
argument copy_from(const argument& input) const;
// (optional)
argument allocate(const shape& s) const;
};
#else
struct target struct target
{ {
...@@ -382,6 +388,7 @@ inline const ValueType& any_cast(const target& x) ...@@ -382,6 +388,7 @@ inline const ValueType& any_cast(const target& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -120,10 +120,8 @@ struct tensor_view ...@@ -120,10 +120,8 @@ struct tensor_view
return m_data[m_shape.index(this->size() - 1)]; return m_data[m_shape.index(this->size() - 1)];
} }
// cppcheck-suppress functionConst
iterator begin() { return {0, {this}}; } iterator begin() { return {0, {this}}; }
// cppcheck-suppress functionConst
iterator end() { return {this->size(), {this}}; } iterator end() { return {this->size(), {this}}; }
const_iterator begin() const { return {0, {this}}; } const_iterator begin() const { return {0, {this}}; }
......
...@@ -178,6 +178,7 @@ struct value ...@@ -178,6 +178,7 @@ struct value
value(std::nullptr_t); value(std::nullptr_t);
value(const char* i); value(const char* i);
value(const std::string& pkey, const char* i);
#define MIGRAPHX_VALUE_GENERATE_DECL_METHODS(vt, cpp_type) \ #define MIGRAPHX_VALUE_GENERATE_DECL_METHODS(vt, cpp_type) \
value(cpp_type i); \ value(cpp_type i); \
...@@ -188,6 +189,12 @@ struct value ...@@ -188,6 +189,12 @@ struct value
const cpp_type* if_##vt() const; const cpp_type* if_##vt() const;
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DECL_METHODS) MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DECL_METHODS)
template <class T>
using literal_to_string = std::conditional_t<(std::is_convertible<T, const char*>{} and
std::is_convertible<T, std::string>{}),
std::string,
T>;
template <class T> template <class T>
using pick_numeric = std::conditional_t< using pick_numeric = std::conditional_t<
std::is_floating_point<T>{}, std::is_floating_point<T>{},
...@@ -246,6 +253,7 @@ struct value ...@@ -246,6 +253,7 @@ struct value
return *this = from_values(rhs); // NOLINT return *this = from_values(rhs); // NOLINT
} }
value& operator=(const char* c);
value& operator=(std::nullptr_t); value& operator=(std::nullptr_t);
value& operator=(const std::initializer_list<value>& i); value& operator=(const std::initializer_list<value>& i);
...@@ -315,8 +323,7 @@ struct value ...@@ -315,8 +323,7 @@ struct value
{ {
switch(this->get_type()) switch(this->get_type())
{ {
case null_type: case null_type: {
{
std::nullptr_t null{}; std::nullptr_t null{};
if(this->key.empty()) if(this->key.empty())
v(null); v(null);
...@@ -325,8 +332,7 @@ struct value ...@@ -325,8 +332,7 @@ struct value
return; return;
} }
#define MIGRAPHX_VALUE_GENERATE_CASE(vt, cpp_type) \ #define MIGRAPHX_VALUE_GENERATE_CASE(vt, cpp_type) \
case vt##_type: \ case vt##_type: { \
{ \
if(this->key.empty()) \ if(this->key.empty()) \
v(this->get_##vt()); \ v(this->get_##vt()); \
else \ else \
...@@ -346,19 +352,17 @@ struct value ...@@ -346,19 +352,17 @@ struct value
{ {
switch(this->get_type()) switch(this->get_type())
{ {
case null_type: case null_type: {
{
std::nullptr_t null{}; std::nullptr_t null{};
v(null); v(null);
return; return;
} }
#define MIGRAPHX_VALUE_GENERATE_CASE_VALUE(vt, cpp_type) \ #define MIGRAPHX_VALUE_GENERATE_CASE_VALUE(vt, cpp_type) \
case vt##_type: \ case vt##_type: { \
{ \
v(this->get_##vt()); \ v(this->get_##vt()); \
return; \ return; \
} }
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE) MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE)
MIGRAPHX_VALUE_GENERATE_CASE(array, ) MIGRAPHX_VALUE_GENERATE_CASE(array, )
MIGRAPHX_VALUE_GENERATE_CASE(object, ) MIGRAPHX_VALUE_GENERATE_CASE(object, )
} }
...@@ -374,11 +378,11 @@ struct value ...@@ -374,11 +378,11 @@ struct value
} }
template <class To> template <class To>
To value_or(const To& default_value) const literal_to_string<To> value_or(const To& default_value) const
{ {
if(this->is_null()) if(this->is_null())
return default_value; return default_value;
return to<To>(); return to<literal_to_string<To>>();
} }
template <class To> template <class To>
...@@ -394,12 +398,12 @@ struct value ...@@ -394,12 +398,12 @@ struct value
} }
template <class To> template <class To>
To get(const std::string& pkey, const To& default_value) const literal_to_string<To> get(const std::string& pkey, const To& default_value) const
{ {
const auto* v = find(pkey); const auto* v = find(pkey);
if(v == this->end()) if(v == this->end())
return default_value; return default_value;
return v->to<To>(); return v->to<literal_to_string<To>>();
} }
template <class To> template <class To>
...@@ -412,10 +416,11 @@ struct value ...@@ -412,10 +416,11 @@ struct value
} }
template <class To> template <class To>
std::vector<To> get(const std::string& pkey, std::vector<literal_to_string<To>> get(const std::string& pkey,
const std::initializer_list<To>& default_value) const const std::initializer_list<To>& default_value) const
{ {
return get<std::vector<To>>(pkey, default_value); return get(pkey,
std::vector<literal_to_string<To>>{default_value.begin(), default_value.end()});
} }
friend bool operator==(const value& x, const value& y); friend bool operator==(const value& x, const value& y);
...@@ -429,6 +434,8 @@ struct value ...@@ -429,6 +434,8 @@ struct value
void debug_print(bool show_type = false) const; void debug_print(bool show_type = false) const;
type_t get_type() const;
private: private:
template <class T> template <class T>
std::vector<value> from_values(const T& r) std::vector<value> from_values(const T& r)
...@@ -438,7 +445,6 @@ struct value ...@@ -438,7 +445,6 @@ struct value
r.begin(), r.end(), std::back_inserter(v), [&](auto&& e) { return value(e); }); r.begin(), r.end(), std::back_inserter(v), [&](auto&& e) { return value(e); });
return v; return v;
} }
type_t get_type() const;
std::shared_ptr<value_base_impl> x; std::shared_ptr<value_base_impl> x;
std::string key; std::string key;
}; };
......
...@@ -168,7 +168,6 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out ...@@ -168,7 +168,6 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out
{ {
double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance; double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance;
auto error = rms_range(r1, r2); auto error = rms_range(r1, r2);
// cppcheck-suppress uninitvar
if(out_error != nullptr) if(out_error != nullptr)
*out_error = error; *out_error = error;
return error <= threshold; return error <= threshold;
......
...@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins, ...@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m) static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{ {
auto op = any_cast<op::pooling>(ins->get_operator()); auto op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "average") if(op.mode == op::pooling_mode::average)
{ {
return; return;
} }
......
...@@ -133,6 +133,12 @@ std::string to_json_string(const value& val) ...@@ -133,6 +133,12 @@ std::string to_json_string(const value& val)
return j.dump(); return j.dump();
} }
std::string to_pretty_json_string(const value& val, std::size_t indent)
{
json j = val;
return j.dump(indent);
}
migraphx::value from_json_string(const char* str, std::size_t size) migraphx::value from_json_string(const char* str, std::size_t size)
{ {
json j = json::parse(str, str + size); json j = json::parse(str, str + size);
......
...@@ -5,20 +5,41 @@ namespace migraphx { ...@@ -5,20 +5,41 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name) { return load_op(name); } operation make_op(const std::string& name) { return load_op(name); }
operation make_op(const std::string& name, const value& v)
template <class F>
operation make_op_generic(const std::string& name, F for_each)
{ {
if(not(v.is_object() or (v.empty() and v.is_array())))
MIGRAPHX_THROW("Value is not an object");
auto op = load_op(name); auto op = load_op(name);
// Merge values // Merge values
value w = op.to_value(); value w = op.to_value();
for(auto&& x : v) for_each([&](const auto& key, const auto& x) {
{ if(not w.contains(key))
w.at(x.get_key()) = x.without_key(); // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
} MIGRAPHX_THROW("No key '" + key + "' in " + name);
w.at(key) = x;
});
op.from_value(w); op.from_value(w);
return op; return op;
} }
operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v)
{
return make_op_generic(name, [&](auto f) {
for(auto&& [key, x] : v)
f(key, x);
});
}
operation make_op_from_value(const std::string& name, const value& v)
{
if(not(v.is_object() or (v.empty() and v.is_array())))
MIGRAPHX_THROW("Value is not an object for make_op: " + name);
return make_op_generic(name, [&](auto f) {
for(auto&& x : v)
f(x.get_key(), x.without_key());
});
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -23,6 +23,8 @@ ...@@ -23,6 +23,8 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_FINALIZE)
struct module_impl struct module_impl
{ {
// A list is used to keep references to an instruction stable // A list is used to keep references to an instruction stable
...@@ -554,8 +556,14 @@ instruction_ref module::find_dangling_reference() const ...@@ -554,8 +556,14 @@ instruction_ref module::find_dangling_reference() const
void module::finalize(context& ctx) void module::finalize(context& ctx)
{ {
const bool trace = enabled(MIGRAPHX_TRACE_FINALIZE{});
for(auto ins : iterator_for(*this)) for(auto ins : iterator_for(*this))
{ {
if(trace)
{
std::cout << "Finalize: ";
this->debug_print(ins);
}
ins->finalize(ctx); ins->finalize(ctx);
for(const auto& smod : ins->module_inputs()) for(const auto& smod : ins->module_inputs())
{ {
...@@ -628,8 +636,9 @@ std::unordered_map<instruction_ref, std::string> module::print( ...@@ -628,8 +636,9 @@ std::unordered_map<instruction_ref, std::string> module::print(
var_name = this->name(); var_name = this->name();
var_name.append((this->name().empty() ? "@" : ":@")); var_name.append((this->name().empty() ? "@" : ":@"));
var_name.append(std::to_string(count)); var_name.append(std::to_string(count));
count++;
} }
// count every instruction so index matches loc in the printout program
count++;
names.emplace(ins, var_name); names.emplace(ins, var_name);
print_func(ins, names); print_func(ins, names);
...@@ -719,7 +728,6 @@ module::print_cpp(std::ostream& os, ...@@ -719,7 +728,6 @@ module::print_cpp(std::ostream& os,
const std::string& mname, const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const std::unordered_map<instruction_ref, std::string> names) const
{ {
// cppcheck-suppress variableScope // cppcheck-suppress variableScope
unsigned long seed = names.size(); unsigned long seed = names.size();
auto last = std::prev(this->end()); auto last = std::prev(this->end());
......
...@@ -14,44 +14,36 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) ...@@ -14,44 +14,36 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
{ {
switch(o.type) switch(o.type)
{ {
case msgpack::type::NIL: case msgpack::type::NIL: {
{
v = nullptr; v = nullptr;
break; break;
} }
case msgpack::type::BOOLEAN: case msgpack::type::BOOLEAN: {
{
v = o.as<bool>(); v = o.as<bool>();
break; break;
} }
case msgpack::type::POSITIVE_INTEGER: case msgpack::type::POSITIVE_INTEGER: {
{
v = o.as<std::uint64_t>(); v = o.as<std::uint64_t>();
break; break;
} }
case msgpack::type::NEGATIVE_INTEGER: case msgpack::type::NEGATIVE_INTEGER: {
{
v = o.as<std::int64_t>(); v = o.as<std::int64_t>();
break; break;
} }
case msgpack::type::FLOAT32: case msgpack::type::FLOAT32:
case msgpack::type::FLOAT64: case msgpack::type::FLOAT64: {
{
v = o.as<double>(); v = o.as<double>();
break; break;
} }
case msgpack::type::STR: case msgpack::type::STR: {
{
v = o.as<std::string>(); v = o.as<std::string>();
break; break;
} }
case msgpack::type::BIN: case msgpack::type::BIN: {
{
v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size}; v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size};
break; break;
} }
case msgpack::type::ARRAY: case msgpack::type::ARRAY: {
{
migraphx::value r = migraphx::value::array{}; migraphx::value r = migraphx::value::array{};
std::for_each( std::for_each(
o.via.array.ptr, o.via.array.ptr,
...@@ -60,8 +52,7 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) ...@@ -60,8 +52,7 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
v = r; v = r;
break; break;
} }
case msgpack::type::MAP: case msgpack::type::MAP: {
{
migraphx::value r = migraphx::value::object{}; migraphx::value r = migraphx::value::object{};
std::for_each(o.via.map.ptr, std::for_each(o.via.map.ptr,
o.via.map.ptr + o.via.map.size, o.via.map.ptr + o.via.map.size,
...@@ -71,7 +62,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) ...@@ -71,7 +62,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
v = r; v = r;
break; break;
} }
case msgpack::type::EXT: { MIGRAPHX_THROW("msgpack EXT type not supported."); case msgpack::type::EXT: {
MIGRAPHX_THROW("msgpack EXT type not supported.");
} }
} }
return o; return o;
......
...@@ -7,7 +7,7 @@ target_compile_options(onnx-proto PRIVATE -w) ...@@ -7,7 +7,7 @@ target_compile_options(onnx-proto PRIVATE -w)
target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY}) target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On) set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
file(GLOB ONNX_SRCS *.cpp) file(GLOB ONNX_SRCS ${CONFIGURE_DEPENDS} *.cpp)
add_library(migraphx_onnx ${ONNX_SRCS}) add_library(migraphx_onnx ${ONNX_SRCS})
target_include_directories(migraphx_onnx PRIVATE include) target_include_directories(migraphx_onnx PRIVATE include)
set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx) set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
......
...@@ -32,9 +32,20 @@ struct onnx_parser ...@@ -32,9 +32,20 @@ struct onnx_parser
instruction_ref add_bias(const std::vector<instruction_ref>& args, instruction_ref add_bias(const std::vector<instruction_ref>& args,
instruction_ref curr_ins, instruction_ref curr_ins,
uint64_t axis) const; uint64_t axis) const;
instruction_ref add_broadcastable_binary_op(const std::string& op_name, instruction_ref add_broadcastable_binary_op(const std::string& op_name,
instruction_ref arg0, instruction_ref arg0,
instruction_ref arg1) const; instruction_ref arg1) const;
instruction_ref add_common_op(const std::string& op_name,
std::vector<instruction_ref> inputs) const;
template <class... Ts>
instruction_ref add_common_op(const std::string& op_name, Ts... xs) const
{
return add_common_op(op_name, {xs...});
}
instruction_ref add_instruction(const operation& op, instruction_ref add_instruction(const operation& op,
const std::vector<instruction_ref>& args) const; const std::vector<instruction_ref>& args) const;
......
...@@ -70,12 +70,14 @@ static literal from_repeated(shape::type_t t, const T& r) ...@@ -70,12 +70,14 @@ static literal from_repeated(shape::type_t t, const T& r)
instruction_ref onnx_parser::node_info::make_contiguous(instruction_ref ins) const instruction_ref onnx_parser::node_info::make_contiguous(instruction_ref ins) const
{ {
if(ins->get_shape().standard()) auto attr = ins->get_operator().to_value();
std::string key = "require_std_shape";
if((attr.get(key, false)) or (not ins->get_shape().standard()))
{ {
return ins; return add_instruction(make_op("contiguous"), ins);
} }
return add_instruction(make_op("contiguous"), ins); return ins;
} }
instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_ref>& args, instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_ref>& args,
...@@ -96,7 +98,13 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s ...@@ -96,7 +98,13 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s
instruction_ref arg0, instruction_ref arg0,
instruction_ref arg1) const instruction_ref arg1) const
{ {
return add_common_op(*mod, make_op(op_name), {arg0, arg1}); return this->add_common_op(op_name, arg0, arg1);
}
instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name,
std::vector<instruction_ref> inputs) const
{
return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs));
} }
instruction_ref instruction_ref
...@@ -380,8 +388,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const ...@@ -380,8 +388,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, t.int64_data()); case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, t.int64_data());
case onnx::TensorProto::UINT64: case onnx::TensorProto::UINT64:
return create_literal(shape::uint64_type, dims, t.uint64_data()); return create_literal(shape::uint64_type, dims, t.uint64_data());
case onnx::TensorProto::FLOAT16: case onnx::TensorProto::FLOAT16: {
{
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end()); std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
std::vector<half> data_half; std::vector<half> data_half;
std::transform(data_uint16.begin(), std::transform(data_uint16.begin(),
...@@ -451,7 +458,8 @@ shape::type_t get_type(int dtype) ...@@ -451,7 +458,8 @@ shape::type_t get_type(int dtype)
case 11: return shape::double_type; case 11: return shape::double_type;
case 12: return shape::uint32_type; case 12: return shape::uint32_type;
case 13: return shape::uint64_type; case 13: return shape::uint64_type;
default: { MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported"); default: {
MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
} }
} }
} }
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/pad_calc.hpp> #include <migraphx/pad_calc.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -94,7 +95,7 @@ void tune_padding_size(const value& v, ...@@ -94,7 +95,7 @@ void tune_padding_size(const value& v,
std::vector<int64_t>& s_start) std::vector<int64_t>& s_start)
{ {
// maxpooling or count_include_pad is 1, no change is required. // maxpooling or count_include_pad is 1, no change is required.
if(v.at("mode").to<std::string>() == "max" or count_include_pad == 1) if(v.at("mode").to<op::pooling_mode>() == op::pooling_mode::max or count_include_pad == 1)
{ {
return; return;
} }
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_celu : op_parser<parse_celu>
{
std::vector<op_desc> operators() const { return {{"Celu"}}; }
instruction_ref parse(const op_desc&,
const onnx_parser&,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
float alpha = 1.0;
if(contains(info.attributes, "alpha"))
{
alpha = info.attributes.at("alpha").f();
}
if(float_equal(alpha, 0.0f))
{
MIGRAPHX_THROW("CELU: alpha is zero (division by zero)");
}
auto input_lens = args[0]->get_shape().lens();
auto input_type = args[0]->get_shape().type();
if(input_type != migraphx::shape::float_type)
{
MIGRAPHX_THROW("CELU: input tensor not float type");
}
auto zero_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {0.}}));
auto one_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1.}}));
auto alpha_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}}));
auto linear_part = info.add_instruction(migraphx::make_op("max"), zero_lit, args[0]);
auto divi = info.add_instruction(migraphx::make_op("div"), args[0], alpha_lit);
auto expo = info.add_instruction(migraphx::make_op("exp"), divi);
auto sub = info.add_instruction(migraphx::make_op("sub"), expo, one_lit);
auto mul = info.add_instruction(migraphx::make_op("mul"), alpha_lit, sub);
auto exp_part = info.add_instruction(migraphx::make_op("min"), zero_lit, mul);
return info.add_instruction(migraphx::make_op("add"), linear_part, exp_part);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -16,7 +16,6 @@ struct parse_clip : op_parser<parse_clip> ...@@ -16,7 +16,6 @@ struct parse_clip : op_parser<parse_clip>
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
auto input_lens = args[0]->get_shape().lens();
instruction_ref min_arg; instruction_ref min_arg;
instruction_ref max_arg; instruction_ref max_arg;
bool min_used = false; bool min_used = false;
...@@ -45,29 +44,17 @@ struct parse_clip : op_parser<parse_clip> ...@@ -45,29 +44,17 @@ struct parse_clip : op_parser<parse_clip>
max_used = true; max_used = true;
} }
if(min_used)
{
min_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
min_arg);
}
if(max_used)
{
max_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
max_arg);
}
if(min_used and max_used) if(min_used and max_used)
{ {
return info.add_instruction(make_op("clip"), args[0], min_arg, max_arg); return info.add_common_op("clip", args[0], min_arg, max_arg);
} }
else if(max_used) else if(max_used)
{ {
return info.add_instruction(make_op("min"), args[0], max_arg); return info.add_broadcastable_binary_op("min", args[0], max_arg);
} }
else if(min_used) else if(min_used)
{ {
return info.add_instruction(make_op("max"), args[0], min_arg); return info.add_broadcastable_binary_op("max", args[0], min_arg);
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment