Unverified Commit 1dd4e4d9 authored by Paul Fultz II, committed by GitHub

Refactor onnx parser (#699)



* Load op when serializing

* Formatting

* Add missing clip field

* Use make_op almost everywhere

* Formatting

* More make ops for rnns

* Get rid of spaces

* Formatting

* Remove operators headers

* Formatting

* Remove unused op headers

* Increase line threshold

* Refactor onnx_parser class

* Formatting

* Add op_parser

* Formatting

* Remove old onnx drivers

* Use file GLOB

* Parse arg ops

* Formatting

* Add pooling

* Formatting

* Add parse_batchnorm

* Add more operators

* Formatting

* Add more operators

* Formatting

* Add more operators

* Formatting

* Add more operators

* Add rnn operators

* Formatting

* Fix tidy issues

* Formatting

* Add back missing param

* Formatting

* Fix shadow variable

* Fix shadow in declaration

* Make global constant

* Formatting

* Add generic op

* Formatting

* Add binary op

* Formatting

* Add variadic op

* Formatting

* Remove unused fields and functions

* Set default values

* Formatting

* Remove unused member variable

* Add add_literal overload

* Use info.add_literal

* Formatting

* Call add_instruction through info class

* Fix tidy issues

* Formatting
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 69d2e38f
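The diff below adds one small translation unit per operator family. All of them share the same self-registration pattern; here is a minimal sketch of it, where the operator name "Example" and the struct parse_example are hypothetical, for illustration only (the shape of operators() and parse() mirrors the real parsers that follow):

#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

// Hypothetical parser: registers itself for the ONNX node type "Example"
// and maps it to the migraphx operator "example".
struct parse_example : op_parser<parse_example>
{
    // ONNX operator name(s) this parser handles; the second entry of each
    // op_desc is the migraphx op name, made available as opd.op_name
    std::vector<op_desc> operators() const { return {{"Example", "example"}}; }

    instruction_ref parse(const op_desc& opd,
                          const onnx_parser& /*parser*/,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        // Instructions are added to the program through the node_info helper
        return info.add_instruction(make_op(opd.op_name), args);
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx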
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/batch_norm_inference.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_batchnorm : op_parser<parse_batchnorm>
{
    std::vector<op_desc> operators() const { return {{"BatchNormalization"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          const std::vector<instruction_ref>& args) const
    {
        float epsilon  = 1e-5f;
        float momentum = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
        }
        if(contains(info.attributes, "momentum"))
        {
            momentum = parser.parse_value(info.attributes.at("momentum")).at<float>();
        }
        if(contains(info.attributes, "spatial"))
        {
            bn_mode = (parser.parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
                          ? op::batch_norm_inference::spatial
                          : op::batch_norm_inference::per_activation;
        }
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
        return info.add_instruction(op, args);
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_binary_op : op_parser<parse_binary_op>
{
    std::vector<op_desc> operators() const
    {
        return {{"Add", "add"},
                {"Div", "div"},
                {"Mul", "mul"},
                {"Pow", "pow"},
                {"PRelu", "prelu"},
                {"Sub", "sub"}};
    }

    instruction_ref parse(const op_desc& opd,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        if(args.size() != 2)
            MIGRAPHX_THROW("binary operators should have 2 operands");
        if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis"))
        {
            uint64_t broadcasted =
                parser.parse_value(info.attributes.at("broadcast")).at<uint64_t>();
            if(broadcasted != 0)
            {
                uint64_t axis = parser.parse_value(info.attributes.at("axis")).at<uint64_t>();
                auto l        = info.add_instruction(
                    make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}),
                    args[1]);
                return info.add_instruction(make_op(opd.op_name), args[0], l);
            }
            return info.add_instruction(make_op(opd.op_name), args);
        }
        else
        {
            return info.add_broadcastable_binary_op(opd.op_name, args[0], args[1]);
        }
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
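The legacy broadcast/axis path in parse_binary_op above can be hard to visualize; a hedged worked example, with the shapes assumed for illustration:

// With legacy ONNX attributes broadcast=1 and axis=1, the smaller operand is
// explicitly broadcast against the first operand before the binary op:
//   args[0]: shape {2,3,4}
//   args[1]: shape {3}   -> broadcast(axis=1, dims={2,3,4}) -> shape {2,3,4}
//   result : add(args[0], broadcasted args[1]), shape {2,3,4}
// Without those attributes, add_broadcastable_binary_op applies standard
// multidirectional (numpy-style) broadcasting instead.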
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_cast : op_parser<parse_cast>
{
    std::vector<op_desc> operators() const { return {{"Cast"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          const std::vector<instruction_ref>& args) const
    {
        if(!contains(info.attributes, "to"))
        {
            MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
        }
        int to_type        = parser.parse_value(info.attributes.at("to")).at<int>();
        shape::type_t type = get_type(to_type);
        return info.add_instruction(make_op("convert", {{"target_type", type}}), args);
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_clip : op_parser<parse_clip>
{
    std::vector<op_desc> operators() const { return {{"Clip"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        auto input_lens = args[0]->get_shape().lens();
        instruction_ref min_arg;
        instruction_ref max_arg;
        bool min_used = false;
        bool max_used = false;

        if(args.size() == 3 and args[2]->name() != "undefined")
        {
            max_arg  = args[2];
            max_used = true;
        }

        if(args.size() >= 2 and args[1]->name() != "undefined")
        {
            min_arg  = args[1];
            min_used = true;
        }
        // if using a previous opset, min/max come from attributes instead of inputs
        else if(contains(info.attributes, "min") and contains(info.attributes, "max"))
        {
            float min_val = parser.parse_value(info.attributes.at("min")).at<float>();
            float max_val = parser.parse_value(info.attributes.at("max")).at<float>();
            min_arg       = info.add_literal(min_val);
            max_arg       = info.add_literal(max_val);
            min_used      = true;
            max_used      = true;
        }

        if(min_used)
        {
            min_arg = info.add_instruction(
                make_op("multibroadcast", {{"output_lens", input_lens}}), min_arg);
        }

        if(max_used)
        {
            max_arg = info.add_instruction(
                make_op("multibroadcast", {{"output_lens", input_lens}}), max_arg);
        }

        if(min_used and max_used)
        {
            return info.add_instruction(make_op("clip"), args[0], min_arg, max_arg);
        }
        else if(max_used)
        {
            return info.add_instruction(make_op("min"), args[0], max_arg);
        }
        else if(min_used)
        {
            return info.add_instruction(make_op("max"), args[0], min_arg);
        }
        else
        {
            return info.add_instruction(make_op("identity"), args[0]);
        }
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
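The fallback chain at the end of parse_clip maps each combination of bounds to a different operator; a summary of the lowering, with the bound values assumed for illustration:

// Clip(x, min=0, max=6) -> clip(x, 0, 6)
// Clip(x, max=6)        -> min(x, 6)    // only an upper bound: elementwise min
// Clip(x, min=0)        -> max(x, 0)    // only a lower bound: elementwise max
// Clip(x)               -> identity(x)  // no bounds: pass-through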
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_compare_op : op_parser<parse_compare_op>
{
    std::vector<op_desc> operators() const
    {
        return {{"Equal", "equal"}, {"Greater", "greater"}, {"Less", "less"}};
    }

    instruction_ref parse(const op_desc& opd,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        auto l = info.add_broadcastable_binary_op(opd.op_name, args[0], args[1]);
        if(l->get_shape().type() != shape::bool_type)
        {
            l = info.add_instruction(make_op("convert", {{"target_type", shape::bool_type}}), l);
        }
        return l;
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_constant : op_parser<parse_constant>
{
    std::vector<op_desc> operators() const { return {{"Constant"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          const std::vector<instruction_ref>& /*args*/) const
    {
        literal v = parser.parse_value(info.attributes.at("value"));
        // return empty literal
        if(v.get_shape().elements() == 0)
        {
            return info.add_literal(literal{});
        }

        auto dim_size = info.attributes.at("value").t().dims_size();
        // if dim_size is 0, it is a scalar
        if(dim_size == 0)
        {
            migraphx::shape scalar_shape{v.get_shape().type()};
            return info.add_literal(migraphx::literal{scalar_shape, v.data()});
        }

        return info.add_literal(v);
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

// Use a literal instruction to replace the ConstantFill operator. In RNNs, the
// input shape and value are fixed, so there is no need to do the actual
// computation for the ConstantFill operator.
struct parse_constant_fill : op_parser<parse_constant_fill>
{
    std::vector<op_desc> operators() const { return {{"ConstantFill"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        int input_as_shape = 0;
        int dtype          = 1;
        float value        = 0.0f;

        if(contains(info.attributes, "dtype"))
        {
            dtype = parser.parse_value(info.attributes.at("dtype")).at<int>();
        }
        shape::type_t type = get_type(dtype);

        if(contains(info.attributes, "input_as_shape"))
        {
            input_as_shape = parser.parse_value(info.attributes.at("input_as_shape")).at<int>();
        }

        if(contains(info.attributes, "value"))
        {
            value = parser.parse_value(info.attributes.at("value")).at<float>();
        }

        if(contains(info.attributes, "extra_shape"))
        {
            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
        }

        if(input_as_shape == 1)
        {
            if(args.size() != 1)
            {
                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
            }
            if(contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: cannot set the shape attribute and pass in an "
                               "input at the same time");
            }
            migraphx::argument in = args[0]->eval();
            check_arg_empty(in, "ConstantFill: dynamic shape is not supported");
            std::vector<std::size_t> dims;
            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
            migraphx::shape s(type, dims);
            std::vector<float> values(s.elements(), value);
            return info.add_literal(migraphx::literal(s, values));
        }
        else if(input_as_shape == 0)
        {
            if(!contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
            }
            literal ls = parser.parse_value(info.attributes.at("shape"));
            std::vector<std::size_t> dims;
            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
            migraphx::shape s{type, dims};
            std::vector<float> values(s.elements(), value);
            return info.add_literal(migraphx::literal(s, values));
        }
        else
        {
            MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
        }
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
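For concreteness, a minimal sketch (attribute values assumed, not from this diff) of the literal that parse_constant_fill folds a typical node into:

#include <vector>
#include <migraphx/shape.hpp>
#include <migraphx/literal.hpp>

// ConstantFill(shape={2,3}, value=1.0, dtype=float) becomes a constant literal:
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> values(s.elements(), 1.0f); // six copies of 1.0f
migraphx::literal lit(s, values);              // inserted via info.add_literal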
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
{
    std::vector<op_desc> operators() const { return {{"ConstantOfShape"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        literal l_val{};
        if(contains(info.attributes, "value"))
        {
            l_val = parser.parse_value(info.attributes.at("value"));
            if(l_val.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 element!");
            }
        }
        else
        {
            l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
        }

        auto type = l_val.get_shape().type();
        if(args.empty())
        {
            MIGRAPHX_THROW("ConstantOfShape: must have 1 input!");
        }
        else
        {
            migraphx::shape s;
            // an empty input tensor means the output is a scalar
            if(args[0]->get_shape().elements() == 0)
            {
                s = migraphx::shape{type, {1}, {0}};
            }
            else
            {
                migraphx::argument in = args[0]->eval();
                check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");
                std::vector<std::size_t> dims;
                in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
                s = migraphx::shape{type, dims};
            }

            literal l_out{};
            l_val.visit([&](auto val) {
                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
                // l_val contains only one element
                std::vector<val_type> out_vec(s.elements(), val.front());
                l_out = literal(s, out_vec);
            });
            return info.add_literal(l_out);
        }
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_dropout : op_parser<parse_dropout>
{
    std::vector<op_desc> operators() const { return {{"Dropout"}}; }

    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       const onnx_parser& /*parser*/,
                                       const onnx_parser::node_info& info,
                                       std::vector<instruction_ref> args) const
    {
        auto out = info.add_instruction(make_op("identity"), args[0]);
        auto s   = args[0]->get_shape();
        std::vector<int8_t> vec(s.elements(), 1);
        shape mask_s{shape::bool_type, s.lens()};
        auto mask = info.add_literal(literal(mask_s, vec));
        return {out, mask};
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx