"...lm-evaluation-harness.git" did not exist on "8dbd24f6603fa219991f4621545f0ca0edc30709"
Commit e2eb6036 authored by Paul's avatar Paul
Browse files

Merge

parents 298c93d5 1e0bbd78
...@@ -36,20 +36,26 @@ struct stream_model ...@@ -36,20 +36,26 @@ struct stream_model
#else #else
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct stream_model struct stream_model
* { {
* std::size_t get_nstream() const; //
* std::size_t get_stream(instruction_ref ins) const; std::size_t get_nstream() const;
* std::size_t get_event_id(instruction_ref ins) const; //
* bool has_stream(instruction_ref ins) const; std::size_t get_stream(instruction_ref ins) const;
* bool is_record(instruction_ref ins) const; //
* bool is_wait(instruction_ref ins) const; std::size_t get_event_id(instruction_ref ins) const;
* }; //
* bool has_stream(instruction_ref ins) const;
*/ //
bool is_record(instruction_ref ins) const;
//
bool is_wait(instruction_ref ins) const;
};
#else
struct stream_model struct stream_model
{ {
...@@ -296,6 +302,7 @@ inline const ValueType& any_cast(const stream_model& x) ...@@ -296,6 +302,7 @@ inline const ValueType& any_cast(const stream_model& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -131,16 +131,17 @@ inline std::string interpolate_string(const std::string& input, ...@@ -131,16 +131,17 @@ inline std::string interpolate_string(const std::string& input,
std::string start = "${", std::string start = "${",
std::string end = "}") std::string end = "}")
{ {
return interpolate_string(input, return interpolate_string(
[&](auto start_it, auto last_it) { input,
auto key = trim({start_it, last_it}); [&](auto start_it, auto last_it) {
auto it = vars.find(key); auto key = trim({start_it, last_it});
if(it == vars.end()) auto it = vars.find(key);
throw std::runtime_error("Unknown key: " + key); if(it == vars.end())
return it->second; throw std::runtime_error("Unknown key: " + key);
}, return it->second;
std::move(start), },
std::move(end)); std::move(start),
std::move(end));
} }
template <class Iterator> template <class Iterator>
......
...@@ -82,20 +82,26 @@ argument copy_from_target(T&, const argument& arg) ...@@ -82,20 +82,26 @@ argument copy_from_target(T&, const argument& arg)
return arg; return arg;
} }
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct target struct target
* { {
* std::string name() const; //
* std::vector<pass> get_passes(context& ctx,const compile_options& options) const; std::string name() const;
* context get_context() const; //
* argument copy_to(const argument& input) const; std::vector<pass> get_passes(context& ctx, const compile_options& options) const;
* argument copy_from(const argument& input) const; //
* argument allocate(const shape& s) const; context get_context() const;
* }; // (optional)
* argument copy_to(const argument& input) const;
*/ // (optional)
argument copy_from(const argument& input) const;
// (optional)
argument allocate(const shape& s) const;
};
#else
struct target struct target
{ {
...@@ -382,6 +388,7 @@ inline const ValueType& any_cast(const target& x) ...@@ -382,6 +388,7 @@ inline const ValueType& any_cast(const target& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -178,6 +178,7 @@ struct value ...@@ -178,6 +178,7 @@ struct value
value(std::nullptr_t); value(std::nullptr_t);
value(const char* i); value(const char* i);
value(const std::string& pkey, const char* i);
#define MIGRAPHX_VALUE_GENERATE_DECL_METHODS(vt, cpp_type) \ #define MIGRAPHX_VALUE_GENERATE_DECL_METHODS(vt, cpp_type) \
value(cpp_type i); \ value(cpp_type i); \
...@@ -188,6 +189,12 @@ struct value ...@@ -188,6 +189,12 @@ struct value
const cpp_type* if_##vt() const; const cpp_type* if_##vt() const;
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DECL_METHODS) MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DECL_METHODS)
template <class T>
using literal_to_string = std::conditional_t<(std::is_convertible<T, const char*>{} and
std::is_convertible<T, std::string>{}),
std::string,
T>;
template <class T> template <class T>
using pick_numeric = std::conditional_t< using pick_numeric = std::conditional_t<
std::is_floating_point<T>{}, std::is_floating_point<T>{},
...@@ -246,6 +253,7 @@ struct value ...@@ -246,6 +253,7 @@ struct value
return *this = from_values(rhs); // NOLINT return *this = from_values(rhs); // NOLINT
} }
value& operator=(const char* c);
value& operator=(std::nullptr_t); value& operator=(std::nullptr_t);
value& operator=(const std::initializer_list<value>& i); value& operator=(const std::initializer_list<value>& i);
...@@ -315,8 +323,7 @@ struct value ...@@ -315,8 +323,7 @@ struct value
{ {
switch(this->get_type()) switch(this->get_type())
{ {
case null_type: case null_type: {
{
std::nullptr_t null{}; std::nullptr_t null{};
if(this->key.empty()) if(this->key.empty())
v(null); v(null);
...@@ -325,8 +332,7 @@ struct value ...@@ -325,8 +332,7 @@ struct value
return; return;
} }
#define MIGRAPHX_VALUE_GENERATE_CASE(vt, cpp_type) \ #define MIGRAPHX_VALUE_GENERATE_CASE(vt, cpp_type) \
case vt##_type: \ case vt##_type: { \
{ \
if(this->key.empty()) \ if(this->key.empty()) \
v(this->get_##vt()); \ v(this->get_##vt()); \
else \ else \
...@@ -346,19 +352,17 @@ struct value ...@@ -346,19 +352,17 @@ struct value
{ {
switch(this->get_type()) switch(this->get_type())
{ {
case null_type: case null_type: {
{
std::nullptr_t null{}; std::nullptr_t null{};
v(null); v(null);
return; return;
} }
#define MIGRAPHX_VALUE_GENERATE_CASE_VALUE(vt, cpp_type) \ #define MIGRAPHX_VALUE_GENERATE_CASE_VALUE(vt, cpp_type) \
case vt##_type: \ case vt##_type: { \
{ \
v(this->get_##vt()); \ v(this->get_##vt()); \
return; \ return; \
} }
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE) MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE)
MIGRAPHX_VALUE_GENERATE_CASE(array, ) MIGRAPHX_VALUE_GENERATE_CASE(array, )
MIGRAPHX_VALUE_GENERATE_CASE(object, ) MIGRAPHX_VALUE_GENERATE_CASE(object, )
} }
...@@ -374,11 +378,11 @@ struct value ...@@ -374,11 +378,11 @@ struct value
} }
template <class To> template <class To>
To value_or(const To& default_value) const literal_to_string<To> value_or(const To& default_value) const
{ {
if(this->is_null()) if(this->is_null())
return default_value; return default_value;
return to<To>(); return to<literal_to_string<To>>();
} }
template <class To> template <class To>
...@@ -394,12 +398,12 @@ struct value ...@@ -394,12 +398,12 @@ struct value
} }
template <class To> template <class To>
To get(const std::string& pkey, const To& default_value) const literal_to_string<To> get(const std::string& pkey, const To& default_value) const
{ {
const auto* v = find(pkey); const auto* v = find(pkey);
if(v == this->end()) if(v == this->end())
return default_value; return default_value;
return v->to<To>(); return v->to<literal_to_string<To>>();
} }
template <class To> template <class To>
...@@ -412,10 +416,11 @@ struct value ...@@ -412,10 +416,11 @@ struct value
} }
template <class To> template <class To>
std::vector<To> get(const std::string& pkey, std::vector<literal_to_string<To>> get(const std::string& pkey,
const std::initializer_list<To>& default_value) const const std::initializer_list<To>& default_value) const
{ {
return get<std::vector<To>>(pkey, default_value); return get(pkey,
std::vector<literal_to_string<To>>{default_value.begin(), default_value.end()});
} }
friend bool operator==(const value& x, const value& y); friend bool operator==(const value& x, const value& y);
...@@ -429,6 +434,8 @@ struct value ...@@ -429,6 +434,8 @@ struct value
void debug_print(bool show_type = false) const; void debug_print(bool show_type = false) const;
type_t get_type() const;
private: private:
template <class T> template <class T>
std::vector<value> from_values(const T& r) std::vector<value> from_values(const T& r)
...@@ -438,7 +445,6 @@ struct value ...@@ -438,7 +445,6 @@ struct value
r.begin(), r.end(), std::back_inserter(v), [&](auto&& e) { return value(e); }); r.begin(), r.end(), std::back_inserter(v), [&](auto&& e) { return value(e); });
return v; return v;
} }
type_t get_type() const;
std::shared_ptr<value_base_impl> x; std::shared_ptr<value_base_impl> x;
std::string key; std::string key;
}; };
......
...@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins, ...@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m) static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{ {
auto op = any_cast<op::pooling>(ins->get_operator()); auto op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "average") if(op.mode == op::pooling_mode::average)
{ {
return; return;
} }
......
...@@ -133,6 +133,12 @@ std::string to_json_string(const value& val) ...@@ -133,6 +133,12 @@ std::string to_json_string(const value& val)
return j.dump(); return j.dump();
} }
std::string to_pretty_json_string(const value& val, std::size_t indent)
{
json j = val;
return j.dump(indent);
}
migraphx::value from_json_string(const char* str, std::size_t size) migraphx::value from_json_string(const char* str, std::size_t size)
{ {
json j = json::parse(str, str + size); json j = json::parse(str, str + size);
......
...@@ -629,8 +629,9 @@ std::unordered_map<instruction_ref, std::string> module::print( ...@@ -629,8 +629,9 @@ std::unordered_map<instruction_ref, std::string> module::print(
var_name = this->name(); var_name = this->name();
var_name.append((this->name().empty() ? "@" : ":@")); var_name.append((this->name().empty() ? "@" : ":@"));
var_name.append(std::to_string(count)); var_name.append(std::to_string(count));
count++;
} }
// count every instruction so index matches loc in the printout program
count++;
names.emplace(ins, var_name); names.emplace(ins, var_name);
print_func(ins, names); print_func(ins, names);
......
...@@ -14,44 +14,36 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) ...@@ -14,44 +14,36 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
{ {
switch(o.type) switch(o.type)
{ {
case msgpack::type::NIL: case msgpack::type::NIL: {
{
v = nullptr; v = nullptr;
break; break;
} }
case msgpack::type::BOOLEAN: case msgpack::type::BOOLEAN: {
{
v = o.as<bool>(); v = o.as<bool>();
break; break;
} }
case msgpack::type::POSITIVE_INTEGER: case msgpack::type::POSITIVE_INTEGER: {
{
v = o.as<std::uint64_t>(); v = o.as<std::uint64_t>();
break; break;
} }
case msgpack::type::NEGATIVE_INTEGER: case msgpack::type::NEGATIVE_INTEGER: {
{
v = o.as<std::int64_t>(); v = o.as<std::int64_t>();
break; break;
} }
case msgpack::type::FLOAT32: case msgpack::type::FLOAT32:
case msgpack::type::FLOAT64: case msgpack::type::FLOAT64: {
{
v = o.as<double>(); v = o.as<double>();
break; break;
} }
case msgpack::type::STR: case msgpack::type::STR: {
{
v = o.as<std::string>(); v = o.as<std::string>();
break; break;
} }
case msgpack::type::BIN: case msgpack::type::BIN: {
{
v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size}; v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size};
break; break;
} }
case msgpack::type::ARRAY: case msgpack::type::ARRAY: {
{
migraphx::value r = migraphx::value::array{}; migraphx::value r = migraphx::value::array{};
std::for_each( std::for_each(
o.via.array.ptr, o.via.array.ptr,
...@@ -60,8 +52,7 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) ...@@ -60,8 +52,7 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
v = r; v = r;
break; break;
} }
case msgpack::type::MAP: case msgpack::type::MAP: {
{
migraphx::value r = migraphx::value::object{}; migraphx::value r = migraphx::value::object{};
std::for_each(o.via.map.ptr, std::for_each(o.via.map.ptr,
o.via.map.ptr + o.via.map.size, o.via.map.ptr + o.via.map.size,
...@@ -71,7 +62,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) ...@@ -71,7 +62,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
v = r; v = r;
break; break;
} }
case msgpack::type::EXT: { MIGRAPHX_THROW("msgpack EXT type not supported."); case msgpack::type::EXT: {
MIGRAPHX_THROW("msgpack EXT type not supported.");
} }
} }
return o; return o;
......
...@@ -7,7 +7,7 @@ target_compile_options(onnx-proto PRIVATE -w) ...@@ -7,7 +7,7 @@ target_compile_options(onnx-proto PRIVATE -w)
target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY}) target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On) set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
file(GLOB ONNX_SRCS *.cpp) file(GLOB ONNX_SRCS ${CONFIGURE_DEPENDS} *.cpp)
add_library(migraphx_onnx ${ONNX_SRCS}) add_library(migraphx_onnx ${ONNX_SRCS})
target_include_directories(migraphx_onnx PRIVATE include) target_include_directories(migraphx_onnx PRIVATE include)
set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx) set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
......
...@@ -32,9 +32,20 @@ struct onnx_parser ...@@ -32,9 +32,20 @@ struct onnx_parser
instruction_ref add_bias(const std::vector<instruction_ref>& args, instruction_ref add_bias(const std::vector<instruction_ref>& args,
instruction_ref curr_ins, instruction_ref curr_ins,
uint64_t axis) const; uint64_t axis) const;
instruction_ref add_broadcastable_binary_op(const std::string& op_name, instruction_ref add_broadcastable_binary_op(const std::string& op_name,
instruction_ref arg0, instruction_ref arg0,
instruction_ref arg1) const; instruction_ref arg1) const;
instruction_ref add_common_op(const std::string& op_name,
std::vector<instruction_ref> inputs) const;
template <class... Ts>
instruction_ref add_common_op(const std::string& op_name, Ts... xs) const
{
return add_common_op(op_name, {xs...});
}
instruction_ref add_instruction(const operation& op, instruction_ref add_instruction(const operation& op,
const std::vector<instruction_ref>& args) const; const std::vector<instruction_ref>& args) const;
......
...@@ -70,12 +70,14 @@ static literal from_repeated(shape::type_t t, const T& r) ...@@ -70,12 +70,14 @@ static literal from_repeated(shape::type_t t, const T& r)
instruction_ref onnx_parser::node_info::make_contiguous(instruction_ref ins) const instruction_ref onnx_parser::node_info::make_contiguous(instruction_ref ins) const
{ {
if(ins->get_shape().standard()) auto attr = ins->get_operator().to_value();
std::string key = "require_std_shape";
if((attr.get(key, false)) or (not ins->get_shape().standard()))
{ {
return ins; return add_instruction(make_op("contiguous"), ins);
} }
return add_instruction(make_op("contiguous"), ins); return ins;
} }
instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_ref>& args, instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_ref>& args,
...@@ -96,7 +98,13 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s ...@@ -96,7 +98,13 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s
instruction_ref arg0, instruction_ref arg0,
instruction_ref arg1) const instruction_ref arg1) const
{ {
return add_common_op(*mod, make_op(op_name), {arg0, arg1}); return this->add_common_op(op_name, arg0, arg1);
}
instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name,
std::vector<instruction_ref> inputs) const
{
return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs));
} }
instruction_ref instruction_ref
...@@ -380,8 +388,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const ...@@ -380,8 +388,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, t.int64_data()); case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, t.int64_data());
case onnx::TensorProto::UINT64: case onnx::TensorProto::UINT64:
return create_literal(shape::uint64_type, dims, t.uint64_data()); return create_literal(shape::uint64_type, dims, t.uint64_data());
case onnx::TensorProto::FLOAT16: case onnx::TensorProto::FLOAT16: {
{
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end()); std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
std::vector<half> data_half; std::vector<half> data_half;
std::transform(data_uint16.begin(), std::transform(data_uint16.begin(),
...@@ -451,7 +458,8 @@ shape::type_t get_type(int dtype) ...@@ -451,7 +458,8 @@ shape::type_t get_type(int dtype)
case 11: return shape::double_type; case 11: return shape::double_type;
case 12: return shape::uint32_type; case 12: return shape::uint32_type;
case 13: return shape::uint64_type; case 13: return shape::uint64_type;
default: { MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported"); default: {
MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
} }
} }
} }
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/pad_calc.hpp> #include <migraphx/pad_calc.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -94,7 +95,7 @@ void tune_padding_size(const value& v, ...@@ -94,7 +95,7 @@ void tune_padding_size(const value& v,
std::vector<int64_t>& s_start) std::vector<int64_t>& s_start)
{ {
// maxpooling or count_include_pad is 1, no change is required. // maxpooling or count_include_pad is 1, no change is required.
if(v.at("mode").to<std::string>() == "max" or count_include_pad == 1) if(v.at("mode").to<op::pooling_mode>() == op::pooling_mode::max or count_include_pad == 1)
{ {
return; return;
} }
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_celu : op_parser<parse_celu>
{
std::vector<op_desc> operators() const { return {{"Celu"}}; }
instruction_ref parse(const op_desc&,
const onnx_parser&,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
float alpha = 1.0;
if(contains(info.attributes, "alpha"))
{
alpha = info.attributes.at("alpha").f();
}
if(float_equal(alpha, 0.0f))
{
MIGRAPHX_THROW("CELU: alpha is zero (division by zero)");
}
auto input_lens = args[0]->get_shape().lens();
auto input_type = args[0]->get_shape().type();
if(input_type != migraphx::shape::float_type)
{
MIGRAPHX_THROW("CELU: input tensor not float type");
}
auto zero_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {0.}}));
auto one_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1.}}));
auto alpha_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}}));
auto linear_part = info.add_instruction(migraphx::make_op("max"), zero_lit, args[0]);
auto divi = info.add_instruction(migraphx::make_op("div"), args[0], alpha_lit);
auto expo = info.add_instruction(migraphx::make_op("exp"), divi);
auto sub = info.add_instruction(migraphx::make_op("sub"), expo, one_lit);
auto mul = info.add_instruction(migraphx::make_op("mul"), alpha_lit, sub);
auto exp_part = info.add_instruction(migraphx::make_op("min"), zero_lit, mul);
return info.add_instruction(migraphx::make_op("add"), linear_part, exp_part);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -16,7 +16,6 @@ struct parse_clip : op_parser<parse_clip> ...@@ -16,7 +16,6 @@ struct parse_clip : op_parser<parse_clip>
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
auto input_lens = args[0]->get_shape().lens();
instruction_ref min_arg; instruction_ref min_arg;
instruction_ref max_arg; instruction_ref max_arg;
bool min_used = false; bool min_used = false;
...@@ -45,29 +44,17 @@ struct parse_clip : op_parser<parse_clip> ...@@ -45,29 +44,17 @@ struct parse_clip : op_parser<parse_clip>
max_used = true; max_used = true;
} }
if(min_used)
{
min_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
min_arg);
}
if(max_used)
{
max_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
max_arg);
}
if(min_used and max_used) if(min_used and max_used)
{ {
return info.add_instruction(make_op("clip"), args[0], min_arg, max_arg); return info.add_common_op("clip", args[0], min_arg, max_arg);
} }
else if(max_used) else if(max_used)
{ {
return info.add_instruction(make_op("min"), args[0], max_arg); return info.add_broadcastable_binary_op("min", args[0], max_arg);
} }
else if(min_used) else if(min_used)
{ {
return info.add_instruction(make_op("max"), args[0], min_arg); return info.add_broadcastable_binary_op("max", args[0], min_arg);
} }
else else
{ {
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_eyelike : op_parser<parse_eyelike>
{
std::vector<op_desc> operators() const { return {{"EyeLike"}}; }
instruction_ref parse(const op_desc&,
const onnx_parser&,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto input_shape = args[0]->get_shape();
auto input_lens = input_shape.lens();
if(input_lens.size() != 2)
{
MIGRAPHX_THROW("EYELIKE: tensor input not of rank 2");
}
std::ptrdiff_t num_rows = input_lens.front();
std::ptrdiff_t num_cols = input_lens.back();
shape::type_t output_type = args[0]->get_shape().type();
if(contains(info.attributes, "dtype"))
{
output_type = get_type(info.attributes.at("dtype").i());
}
std::ptrdiff_t k = 0;
if(contains(info.attributes, "k"))
{
k = info.attributes.at("k").i();
}
if(k >= 0)
{
if(k >= num_cols)
{
std::ostringstream oss;
oss << "EYELIKE: positive k out of bounds, k = " << k << " num_cols = " << num_cols;
MIGRAPHX_THROW(oss.str());
}
}
else
{
if(std::abs(k) >= num_rows)
{
std::ostringstream oss;
oss << "EYELIKE: negative k out of bounds, k = " << k << " num_rows = " << num_cols;
MIGRAPHX_THROW(oss.str());
}
}
std::vector<char> eyelike_mat(num_rows * num_cols, 0);
for(std::ptrdiff_t i = 0; i < num_rows; ++i)
{
auto idx = i + k;
if(idx < num_cols and idx >= 0)
eyelike_mat[(num_cols + 1) * i + k] = char{1};
}
return info.add_literal(
migraphx::literal{migraphx::shape{output_type, input_lens}, eyelike_mat});
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -28,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -28,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Floor", "floor"}, {"Floor", "floor"},
{"Gather", "gather"}, {"Gather", "gather"},
{"Identity", "identity"}, {"Identity", "identity"},
{"IsNaN", "isnan"},
{"LeakyRelu", "leaky_relu"}, {"LeakyRelu", "leaky_relu"},
{"Log", "log"}, {"Log", "log"},
{"LRN", "lrn"}, {"LRN", "lrn"},
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
//! Parser for LpNormalization ONNX operator.
/*!
Normalizes a tensor by the L1 or L2 norms along a given axis.
Norms that evaluate to 0 are changed to 1 to prevent division by zero.
*/
struct parse_lpnormalization : op_parser<parse_lpnormalization>
{
std::vector<op_desc> operators() const { return {{"LpNormalization"}}; }
instruction_ref parse(const op_desc&,
const onnx_parser&,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
int p = 2;
if(contains(info.attributes, "p"))
{
p = info.attributes.at("p").i();
}
if(p != 1 and p != 2)
{
MIGRAPHX_THROW("LPNORMALIZATION: only L1 and L2 norm supported");
}
auto input = args.front();
auto input_shape = input->get_shape();
const auto& input_lens = input_shape.lens();
auto input_type = input_shape.type();
std::ptrdiff_t num_axes = input_lens.size();
std::ptrdiff_t axis = -1;
if(contains(info.attributes, "axis"))
{
axis = info.attributes.at("axis").i();
if(axis < -num_axes or axis >= num_axes)
{
// handled in normalize_attributes but throwing here might be clearer
MIGRAPHX_THROW("LPNORMALIZATION: selected axis out of bounds");
}
}
migraphx::instruction_ref p_val;
if(p == 1)
{
p_val = info.add_instruction(migraphx::make_op("abs"), input);
}
else
{
p_val = info.add_instruction(migraphx::make_op("mul"), input, input);
}
// need to check for zeros from lp norm to prevent division by zero
// change them to 1 for the element-wise division
auto norms =
info.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {axis}}}), p_val);
if(p == 2)
{
norms = info.add_instruction(migraphx::make_op("sqrt"), norms);
}
// broadcast back to initial shape, negative axis option doesn't work with unidirectional
norms = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), norms);
auto zero_mb = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {0.}}));
auto one_mb = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1.}}));
auto is_zero = info.add_instruction(migraphx::make_op("equal"), norms, zero_mb);
auto norms_zeros_to_one =
info.add_instruction(migraphx::make_op("where"), is_zero, one_mb, norms);
return info.add_instruction(migraphx::make_op("div"), input, norms_zeros_to_one);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraphx/onnx/checks.hpp> #include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/padding.hpp> #include <migraphx/onnx/padding.hpp>
#include <migraphx/op/pad.hpp> #include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
...@@ -27,10 +28,16 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -27,10 +28,16 @@ struct parse_pooling : op_parser<parse_pooling>
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
std::string mode = opd.op_name; std::string mode = opd.op_name;
operation op = make_op("pooling", {{"mode", mode}}); if(mode != "max" && mode != "average")
value values = op.to_value(); {
auto l0 = args[0]; MIGRAPHX_THROW("onnx pooling mode must be \"max\" or \"average\"");
auto in_lens = l0->get_shape().lens(); }
operation op = make_op(
"pooling",
{{"mode", mode == "average" ? op::pooling_mode::average : op::pooling_mode::max}});
value values = op.to_value();
auto l0 = args[0];
auto in_lens = l0->get_shape().lens();
assert(in_lens.size() > 2); assert(in_lens.size() > 2);
auto kdims = in_lens.size() - 2; auto kdims = in_lens.size() - 2;
...@@ -72,6 +79,7 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -72,6 +79,7 @@ struct parse_pooling : op_parser<parse_pooling>
std::vector<int64_t> paddings; std::vector<int64_t> paddings;
float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f); float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
if(contains(info.attributes, "pads")) if(contains(info.attributes, "pads"))
{ {
values["padding"].clear(); values["padding"].clear();
......
#include <migraphx/op/common.hpp>
#include <migraphx/onnx/op_parser.hpp> #include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp> #include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
...@@ -28,10 +29,14 @@ struct parse_roialign : op_parser<parse_roialign> ...@@ -28,10 +29,14 @@ struct parse_roialign : op_parser<parse_roialign>
"\": invalid value!"); "\": invalid value!");
} }
std::string mode = "avg"; migraphx::op::pooling_mode rmode(migraphx::op::pooling_mode::average);
if(contains(info.attributes, "mode")) if(contains(info.attributes, "mode"))
{ {
mode = info.attributes.at("mode").s(); // read mode; default is "avg"
if(info.attributes.at("mode").s() == "max")
{
rmode = migraphx::op::pooling_mode::max;
}
} }
int64_t output_height = 1; int64_t output_height = 1;
...@@ -57,10 +62,9 @@ struct parse_roialign : op_parser<parse_roialign> ...@@ -57,10 +62,9 @@ struct parse_roialign : op_parser<parse_roialign>
{ {
spatial_scale = info.attributes.at("spatial_scale").f(); spatial_scale = info.attributes.at("spatial_scale").f();
} }
return info.add_instruction(make_op("roialign", return info.add_instruction(make_op("roialign",
{{"coordinate_transformation_mode", coord_trans_mode}, {{"coordinate_transformation_mode", coord_trans_mode},
{"mode", mode}, {"mode", rmode},
{"output_height", output_height}, {"output_height", output_height},
{"output_width", output_width}, {"output_width", output_width},
{"sampling_ratio", sampling_ratio}, {"sampling_ratio", sampling_ratio},
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_scatternd : op_parser<parse_scatternd>
{
std::vector<op_desc> operators() const { return {{"ScatterND"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref>& args) const
{
if(contains(info.attributes, "reduction"))
{
if(info.attributes.at("reduction").s() == "add")
return info.add_instruction(migraphx::make_op("scatternd_add"), args);
if(info.attributes.at("reduction").s() == "mul")
return info.add_instruction(migraphx::make_op("scatternd_mul"), args);
}
return info.add_instruction(migraphx::make_op("scatternd_none"), args);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment