Unverified Commit 1dd4e4d9 authored by Paul Fultz II, committed by GitHub

Refactor onnx parser (#699)



* Load op when serializing

* Formatting

* Add missing clip field

* Use make_op almost everywhere

* Formatting

* More make ops for rnns

* Get rid of spaces

* Formatting

* Remove operators headers

* Formatting

* Remove unused op headers

* Increase line threshold

* Refactor onnx_parser class

* Formatting

* Add op_parser

* Formatting

* Remove old onnx drivers

* Use file GLOB

* Parse arg ops

* Formatting

* Add pooling

* Formatting

* Add parse_batchnorm

* Add more operators

* Formatting

* Add more operators

* Formatting

* Add more operators

* Formatting

* Add more operators

* Add rnn operators

* Formatting

* Fix tidy issues

* Formatting

* Add back missing param

* Formatting

* Fix shadow variable

* Fix shadow in declaration

* Make global constant

* Formatting

* Add generic op

* Formatting

* Add binary op

* Formatting

* Add variadic op

* Formatting

* Remove unused fields and functions

* Set default values

* Formatting

* Remove unused member variable

* Add add_literal overload

* Use info.add_literal

* Formatting

* Call add_instruction through info class

* Fix tidy issues

* Formatting
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 69d2e38f
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_batchnorm : op_parser<parse_batchnorm>
{
std::vector<op_desc> operators() const { return {{"BatchNormalization"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
float epsilon = 1e-5f;
float momentum = 0.9f;
op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
if(contains(info.attributes, "epsilon"))
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
if(contains(info.attributes, "momentum"))
{
momentum = parser.parse_value(info.attributes.at("momentum")).at<float>();
}
if(contains(info.attributes, "spatial"))
{
bn_mode = (parser.parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
? op::batch_norm_inference::spatial
: op::batch_norm_inference::per_activation;
}
op::batch_norm_inference op{epsilon, momentum, bn_mode};
return info.add_instruction(op, args);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
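// Every file in this commit follows the pattern above: a parser struct derives
// from op_parser<Derived> and lists the ONNX operator names it handles in
// operators(). Below is a minimal standalone sketch of that registration
// idiom; the string-based registry, parse signature, and explicit
// register_parser() call are illustrative stand-ins, not MIGraphX's actual
// machinery.
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

using parse_fn = std::function<std::string(const std::string&)>;

// Hypothetical global registry mapping an ONNX op name to a parse callback.
std::map<std::string, parse_fn>& registry()
{
    static std::map<std::string, parse_fn> r;
    return r;
}

template <class Derived>
struct op_parser
{
    // Register a callback for every operator name the derived parser declares.
    static bool register_parser()
    {
        for(const auto& name : Derived{}.operators())
            registry()[name] = [](const std::string& node) { return Derived{}.parse(node); };
        return true;
    }
};

struct parse_relu : op_parser<parse_relu>
{
    std::vector<std::string> operators() const { return {"Relu"}; }
    std::string parse(const std::string& node) const { return "relu(" + node + ")"; }
};
// MIGraphX registers parsers automatically at startup; here it is explicit.
[[maybe_unused]] static const bool relu_registered = parse_relu::register_parser();

int main()
{
    std::cout << registry().at("Relu")("x") << "\n"; // prints relu(x)
}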
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_binary_op : op_parser<parse_binary_op>
{
std::vector<op_desc> operators() const
{
return {{"Add", "add"},
{"Div", "div"},
{"Mul", "mul"},
{"Pow", "pow"},
{"PRelu", "prelu"},
{"Sub", "sub"}};
}
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands");
if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis"))
{
uint64_t broadcasted =
parser.parse_value(info.attributes.at("broadcast")).at<uint64_t>();
if(broadcasted != 0)
{
uint64_t axis = parser.parse_value(info.attributes.at("axis")).at<uint64_t>();
auto l = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}),
args[1]);
return info.add_instruction(make_op(opd.op_name), args[0], l);
}
return info.add_instruction(make_op(opd.op_name), args);
}
else
{
return info.add_broadcastable_binary_op(opd.op_name, args[0], args[1]);
}
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
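// parse_binary_op above still supports the pre-opset-7 broadcast semantics:
// with broadcast=1, the second input's dims are aligned starting at `axis` of
// the first input's shape and stretched across the remaining dims. A
// standalone sketch of that alignment rule (the shapes here are made up for
// illustration):
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

std::vector<std::size_t> legacy_broadcast_lens(const std::vector<std::size_t>& a_lens,
                                               const std::vector<std::size_t>& b_lens,
                                               std::size_t axis)
{
    // b's dims must fit inside a's shape at the given offset and match exactly.
    if(axis + b_lens.size() > a_lens.size())
        throw std::runtime_error("legacy broadcast: axis out of range");
    for(std::size_t i = 0; i < b_lens.size(); i++)
        if(a_lens[axis + i] != b_lens[i])
            throw std::runtime_error("legacy broadcast: dim mismatch");
    return a_lens; // the output takes the first (larger) input's shape
}

int main()
{
    // Add(A[2,3,4,5], B[3,4]) with broadcast=1, axis=1: B is stretched
    // across dims 0 and 3, matching the "broadcast" op inserted above.
    for(auto d : legacy_broadcast_lens({2, 3, 4, 5}, {3, 4}, 1))
        std::cout << d << " "; // prints: 2 3 4 5
}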
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_cast : op_parser<parse_cast>
{
std::vector<op_desc> operators() const { return {{"Cast"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
if(!contains(info.attributes, "to"))
{
MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
}
int to_type = parser.parse_value(info.attributes.at("to")).at<int>();
shape::type_t type = get_type(to_type);
return info.add_instruction(make_op("convert", {{"target_type", type}}), args);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_clip : op_parser<parse_clip>
{
std::vector<op_desc> operators() const { return {{"Clip"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
auto input_lens = args[0]->get_shape().lens();
instruction_ref min_arg;
instruction_ref max_arg;
bool min_used = false;
bool max_used = false;
if(args.size() == 3 and args[2]->name() != "undefined")
{
max_arg = args[2];
max_used = true;
}
if(args.size() >= 2 and args[1]->name() != "undefined")
{
min_arg = args[1];
min_used = true;
}
// if using an earlier opset, min/max come from attributes instead of inputs
else if(contains(info.attributes, "min") and contains(info.attributes, "max"))
{
float min_val = parser.parse_value(info.attributes.at("min")).at<float>();
float max_val = parser.parse_value(info.attributes.at("max")).at<float>();
min_arg = info.add_literal(min_val);
max_arg = info.add_literal(max_val);
min_used = true;
max_used = true;
}
if(min_used)
{
min_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
min_arg);
}
if(max_used)
{
max_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
max_arg);
}
if(min_used and max_used)
{
return info.add_instruction(make_op("clip"), args[0], min_arg, max_arg);
}
else if(max_used)
{
return info.add_instruction(make_op("min"), args[0], max_arg);
}
else if(min_used)
{
return info.add_instruction(make_op("max"), args[0], min_arg);
}
else
{
return info.add_instruction(make_op("identity"), args[0]);
}
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
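// parse_clip above dispatches on which bounds are present: both -> clip, only
// max -> min, only min -> max, neither -> identity. Elementwise these agree
// with the usual scalar identity clip(x, lo, hi) == max(min(x, hi), lo), as
// this minimal check illustrates:
#include <algorithm>
#include <cassert>

float clip_scalar(float x, float lo, float hi) { return std::max(std::min(x, hi), lo); }

int main()
{
    assert(clip_scalar(-2.0f, -1.0f, 1.0f) == -1.0f); // below the range
    assert(clip_scalar(0.5f, -1.0f, 1.0f) == 0.5f);   // inside the range
    assert(clip_scalar(3.0f, -1.0f, 1.0f) == 1.0f);   // above the range
}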
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_compare_op : op_parser<parse_compare_op>
{
std::vector<op_desc> operators() const
{
return {{"Equal", "equal"}, {"Greater", "greater"}, {"Less", "less"}};
}
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto l = info.add_broadcastable_binary_op(opd.op_name, args[0], args[1]);
if(l->get_shape().type() != shape::bool_type)
{
l = info.add_instruction(make_op("convert", {{"target_type", shape::bool_type}}), l);
}
return l;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_constant : op_parser<parse_constant>
{
std::vector<op_desc> operators() const { return {{"Constant"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
const std::vector<instruction_ref>& /*args*/) const
{
literal v = parser.parse_value(info.attributes.at("value"));
// return empty literal
if(v.get_shape().elements() == 0)
{
return info.add_literal(literal{});
}
auto dim_size = info.attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar
if(dim_size == 0)
{
migraphx::shape scalar_shape{v.get_shape().type()};
return info.add_literal(migraphx::literal{scalar_shape, v.data()});
}
return info.add_literal(v);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Use a literal instruction to replace the ConstantFill operator. In RNNs the
// input shape and value are fixed, so there is no need to do the actual
// computation for the ConstantFill operator at runtime.
struct parse_constant_fill : op_parser<parse_constant_fill>
{
std::vector<op_desc> operators() const { return {{"ConstantFill"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
int input_as_shape = 0;
int dtype = 1;
float value = 0.0f;
if(contains(info.attributes, "dtype"))
{
dtype = parser.parse_value(info.attributes.at("dtype")).at<int>();
}
shape::type_t type = get_type(dtype);
if(contains(info.attributes, "input_as_shape"))
{
input_as_shape = parser.parse_value(info.attributes.at("input_as_shape")).at<int>();
}
if(contains(info.attributes, "value"))
{
value = parser.parse_value(info.attributes.at("value")).at<float>();
}
if(contains(info.attributes, "extra_shape"))
{
MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
}
if(input_as_shape == 1)
{
if(args.size() != 1)
{
MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
}
if(contains(info.attributes, "shape"))
{
MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
"at the same time");
}
migraphx::argument in = args[0]->eval();
check_arg_empty(in, "ConstantFill: dynamic shape is not supported");
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
migraphx::shape s(type, dims);
std::vector<float> values(s.elements(), value);
return info.add_literal(migraphx::literal(s, values));
}
else if(input_as_shape == 0)
{
if(!contains(info.attributes, "shape"))
{
MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
}
literal ls = parser.parse_value(info.attributes.at("shape"));
std::vector<std::size_t> dims;
ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
migraphx::shape s{type, dims};
std::vector<float> values(s.elements(), value);
return info.add_literal(migraphx::literal(s, values));
}
else
{
MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
}
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
{
std::vector<op_desc> operators() const { return {{"ConstantOfShape"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
literal l_val{};
if(contains(info.attributes, "value"))
{
l_val = parser.parse_value(info.attributes.at("value"));
if(l_val.get_shape().elements() != 1)
{
MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 element!");
}
}
else
{
l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
}
auto type = l_val.get_shape().type();
if(args.empty())
{
MIGRAPHX_THROW("ConstantOfShape: must have 1 input!");
}
else
{
migraphx::shape s;
// empty input tensor, output is a scalar
if(args[0]->get_shape().elements() == 0)
{
s = migraphx::shape{type, {1}, {0}};
}
else
{
migraphx::argument in = args[0]->eval();
check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
s = migraphx::shape{type, dims};
}
literal l_out{};
l_val.visit([&](auto val) {
using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
// l_val contains only one element
std::vector<val_type> out_vec(s.elements(), val.front());
l_out = literal(s, out_vec);
});
return info.add_literal(l_out);
}
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/conv.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_convolution : op_parser<parse_convolution>
{
std::vector<op_desc> operators() const
{
return {{"Conv", "convolution"}, {"ConvInteger", "quant_convolution"}};
}
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
auto op = make_op(opd.op_name);
auto values = op.to_value();
auto l0 = args[0];
auto weights = args[1];
auto in_lens = l0->get_shape().lens();
assert(in_lens.size() > 2);
auto kdims = in_lens.size() - 2;
// ensure pads available only when auto_pad is "NOT_SET"
check_padding_mode(info, "CONV");
if(contains(info.attributes, "strides"))
{
values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(kdims, values["stride"].size(), "PARSE_CONV: inconsistent strides");
}
if(contains(info.attributes, "dilations"))
{
values["dilation"].clear();
copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilation"]));
check_attr_sizes(
kdims, values["dilation"].size(), "PARSE_CONV: inconsistent dilations");
}
std::vector<int64_t> padding;
if(contains(info.attributes, "pads"))
{
values["padding"].clear();
copy(info.attributes["pads"].ints(), std::back_inserter(padding));
check_attr_sizes(kdims, padding.size() / 2, "PARSE_CONV: inconsistent paddings");
}
if(contains(info.attributes, "auto_pad"))
{
auto weight_lens = weights->get_shape().lens();
std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
cal_auto_padding_size(info,
values,
k_lens,
values["dilation"].to_vector<std::size_t>(),
in_lens,
padding);
auto auto_pad = info.attributes["auto_pad"].s();
if(auto_pad.find("SAME") != std::string::npos)
{
values["padding_mode"] = to_value(op::padding_mode_t::same);
}
}
check_asym_padding(info, l0, padding, values);
if(contains(info.attributes, "group"))
{
values["group"] = parser.parse_value(info.attributes.at("group")).at<int>();
}
recalc_conv_attributes(values, kdims);
op.from_value(values);
auto l1 = info.add_instruction(op, l0, args[1]);
return info.add_bias(args, l1, 1);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
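// cal_auto_padding_size (declared in padding.hpp, not shown in this diff)
// fills in the SAME_* padding. For reference, the ONNX SAME rule pads each
// spatial dim so the output length is ceil(input / stride); a standalone
// sketch of that computation for one dim (the SAME_UPPER vs SAME_LOWER split
// of the total into begin/end padding is omitted):
#include <cstdint>
#include <iostream>

std::int64_t same_total_pad(std::int64_t in, std::int64_t k, std::int64_t stride,
                            std::int64_t dilation)
{
    std::int64_t out   = (in + stride - 1) / stride; // ceil(in / stride)
    std::int64_t k_eff = (k - 1) * dilation + 1;     // dilated kernel extent
    std::int64_t pad   = (out - 1) * stride + k_eff - in;
    return pad > 0 ? pad : 0;
}

int main()
{
    // 7-wide input, 3-wide kernel, stride 2: output ceil(7/2) = 4,
    // so 2 total padding elements are needed.
    std::cout << same_total_pad(7, 3, 2, 1) << "\n"; // prints 2
}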
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/conv.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
template <class T>
std::vector<int64_t> to_int64_vector(const std::vector<T>& input_vector)
{
std::vector<int64_t> output_vector(input_vector.begin(), input_vector.end());
return output_vector;
}
struct parse_deconvolution : op_parser<parse_deconvolution>
{
std::vector<op_desc> operators() const { return {{"ConvTranspose"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
operation op = make_op("deconvolution");
value values = op.to_value();
auto l0 = args[0];
std::vector<std::int64_t> padding;
bool asym_padding = false;
auto in_lens = l0->get_shape().lens();
assert(in_lens.size() > 2);
auto kdims = in_lens.size() - 2;
// ensure pads available only when auto_pad is "NOT_SET"
check_padding_mode(info, "CONV_TRANSPOSE");
if(contains(info.attributes, "pads"))
{
copy(info.attributes["pads"].ints(), std::back_inserter(padding));
asym_padding = is_asym_padding(padding);
if(not asym_padding)
{
size_t pad_ndims = padding.size() / 2;
check_attr_sizes(kdims, pad_ndims, "PARSE_CONV_TRANSPOSE: inconsistent paddings");
values["padding"].clear();
std::transform(padding.begin(),
padding.begin() + pad_ndims,
std::back_inserter(values["padding"]),
[](auto pad_val) { return pad_val; });
}
}
if(contains(info.attributes, "strides"))
{
values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(
kdims, values["stride"].size(), "PARSE_CONV_TRANSPOSE: inconsistent strides");
}
if(contains(info.attributes, "dilations"))
{
values["dilation"].clear();
copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilation"]));
check_attr_sizes(
kdims, values["dilation"].size(), "PARSE_CONV_TRANSPOSE: inconsistent dilations");
}
if(contains(info.attributes, "auto_pad"))
{
auto s = info.attributes["auto_pad"].s();
if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW("PARSE_CONV_TRANSPOSE: auto_pad and padding cannot be specified "
"simultaneously");
}
if(s.find("SAME") != std::string::npos)
{
values["padding_mode"] = to_value(op::padding_mode_t::same);
}
}
if(contains(info.attributes, "group"))
{
values["group"] = parser.parse_value(info.attributes.at("group")).at<int>();
}
recalc_conv_attributes(values, kdims);
op.from_value(values);
auto l1 = info.add_instruction(op, l0, args[1]);
std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
std::vector<int64_t> curr_shape(dims.begin() + 2, dims.end());
if(asym_padding)
{
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2); // ignore first 2 dims
auto pad_kdim_start = padding.begin() + kdims;
std::vector<int64_t> starts(padding.begin(), pad_kdim_start);
std::vector<int64_t> ends{};
std::transform(curr_shape.begin(),
curr_shape.end(),
pad_kdim_start,
std::back_inserter(ends),
[](auto curr_dim, auto pad_dim) { return curr_dim - pad_dim; });
l1 = info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), l1);
}
if(contains(info.attributes, "output_padding"))
{
size_t non_kdims = dims.size() * 2 - kdims;
std::vector<int64_t> output_padding(non_kdims, 0);
copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding));
check_attr_sizes(kdims,
output_padding.size() - non_kdims,
"PARSE_CONV_TRANSPOSE: inconsistent output padding");
l1 = info.add_instruction(make_op("pad", {{"pads", output_padding}}), l1);
}
if(contains(info.attributes, "output_shape"))
{
std::vector<int64_t> output_shape;
copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape));
check_attr_sizes(
kdims, output_shape.size(), "PARSE_CONV_TRANSPOSE: inconsistent output shape");
dims = to_int64_vector(l1->get_shape().lens());
copy(dims.begin() + 2, dims.end(), curr_shape.begin());
if(curr_shape != output_shape)
{
std::vector<int64_t> target_padding(dims.size() * 2 - kdims, 0);
std::transform(output_shape.begin(),
output_shape.end(),
curr_shape.begin(),
std::back_inserter(target_padding),
[](auto out_dim, auto curr_dim) { return out_dim - curr_dim; });
l1 = info.add_instruction(make_op("pad", {{"pads", target_padding}}), l1);
}
}
return info.add_bias(args, l1, 1);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
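// The output_shape/output_padding handling above relies on the standard
// ConvTranspose size relation. A sketch for one spatial dim, assuming
// symmetric padding: when the requested output_shape is larger than this,
// the parser pads the difference.
#include <cstdint>
#include <iostream>

std::int64_t deconv_out_dim(std::int64_t in, std::int64_t k, std::int64_t stride,
                            std::int64_t pad, std::int64_t dilation, std::int64_t out_pad)
{
    return (in - 1) * stride - 2 * pad + dilation * (k - 1) + 1 + out_pad;
}

int main()
{
    // 4-wide input, 3-wide kernel, stride 2, padding 1: (4-1)*2 - 2 + 3 = 7
    std::cout << deconv_out_dim(4, 3, 2, 1, 1, 0) << "\n"; // prints 7
}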
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_dropout : op_parser<parse_dropout>
{
std::vector<op_desc> operators() const { return {{"Dropout"}}; }
std::vector<instruction_ref> parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto out = info.add_instruction(make_op("identity"), args[0]);
auto s = args[0]->get_shape();
std::vector<int8_t> vec(s.elements(), 1);
shape mask_s{shape::bool_type, s.lens()};
auto mask = info.add_literal(literal(mask_s, vec));
return {out, mask};
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_expand : op_parser<parse_expand>
{
std::vector<op_desc> operators() const { return {{"Expand"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto in_lens = args[0]->get_shape().lens();
migraphx::argument arg_s = args[1]->eval();
check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
return info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
args[0]);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
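// compute_broadcasted_lens (used above and again by parse_matmul below)
// follows numpy-style broadcasting: shapes are aligned from the trailing dim,
// equal dims pass through, and a dim of 1 stretches. A standalone sketch of
// that rule; this is an illustration, not MIGraphX's implementation:
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

std::vector<std::size_t> broadcast_lens(std::vector<std::size_t> a, std::vector<std::size_t> b)
{
    if(a.size() < b.size())
        std::swap(a, b); // make a the longer shape
    auto offset = a.size() - b.size();
    std::vector<std::size_t> out(a);
    for(std::size_t i = 0; i < b.size(); i++)
    {
        auto x = a[offset + i];
        auto y = b[i];
        if(x == y or x == 1 or y == 1)
            out[offset + i] = std::max(x, y);
        else
            throw std::runtime_error("shapes are not broadcastable");
    }
    return out;
}

int main()
{
    for(auto d : broadcast_lens({3, 1, 5}, {4, 1}))
        std::cout << d << " "; // prints: 3 4 5
}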
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_gather_elements : op_parser<parse_gather_elements>
{
std::vector<op_desc> operators() const { return {{"GatherElements"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
int axis = 0;
if(contains(info.attributes, "axis"))
{
axis = parser.parse_value(info.attributes.at("axis")).at<int>();
}
// standardize input data and index
auto arg_data = info.make_contiguous(args[0]);
auto arg_ind = info.make_contiguous(args[1]);
auto data_s = arg_data->get_shape();
auto ind_s = arg_ind->get_shape();
if(data_s.lens().size() != ind_s.lens().size())
{
MIGRAPHX_THROW("PARSE_GATHER_ELEMENTS: input data and index must have the same rank!");
}
int n_rank = static_cast<int>(data_s.lens().size());
int tuned_axis = (axis < 0) ? (axis + n_rank) : axis;
auto axis_stride = data_s.strides()[tuned_axis];
int64_t data_elem_num = static_cast<int64_t>(data_s.elements());
// reshape the input data to one dimension so it can be used as the input
// to the gather operator
arg_data = info.add_instruction(make_op("reshape", {{"dims", {data_elem_num}}}), arg_data);
std::size_t elem_num = ind_s.elements();
std::vector<int> ind_index(elem_num);
std::iota(ind_index.begin(), ind_index.end(), 0);
// convert each position in the index tensor to its flat offset in the input data
std::vector<int> data_indices(elem_num);
std::transform(ind_index.begin(), ind_index.end(), data_indices.begin(), [&](auto i) {
return data_s.index(ind_s.multi(i));
});
std::vector<int> vec_axis_ind(elem_num);
std::transform(ind_index.begin(), ind_index.end(), vec_axis_ind.begin(), [&](auto i) {
return ind_s.multi(i)[tuned_axis];
});
auto l_shape_idx =
info.add_literal(literal(ind_s, data_indices.begin(), data_indices.end()));
auto l_dim_idx = info.add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
auto l_stride = info.add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
l_stride = info.add_instruction(make_op("multibroadcast", {{"output_lens", ind_s.lens()}}),
l_stride);
auto dim_diff = info.add_instruction(make_op("sub"), arg_ind, l_dim_idx);
auto delta = info.add_instruction(make_op("mul"), dim_diff, l_stride);
auto ind = info.add_instruction(make_op("add"), l_shape_idx, delta);
auto op = make_op("gather", {{"axis", 0}});
return info.add_instruction(op, arg_data, ind);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
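// The decomposition above hinges on one identity: for an output position whose
// multi-index along the gather axis is j but whose runtime index value is k,
// base_offset + (k - j) * stride[axis] is exactly the flat offset of the
// element at coordinate k. That is what lets GatherElements become a 1-D
// gather with precomputed indices. A standalone check of the identity for a
// row-major 2x3 shape:
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<std::int64_t> row_major_strides(const std::vector<std::int64_t>& lens)
{
    std::vector<std::int64_t> strides(lens.size(), 1);
    for(std::size_t i = lens.size(); i > 1; i--)
        strides[i - 2] = strides[i - 1] * lens[i - 1];
    return strides;
}

std::int64_t flat_index(const std::vector<std::int64_t>& strides,
                        const std::vector<std::int64_t>& multi)
{
    return std::inner_product(multi.begin(), multi.end(), strides.begin(), std::int64_t{0});
}

int main()
{
    auto strides = row_major_strides({2, 3});
    std::int64_t i = 1, j = 0, k = 2; // index tensor holds k at position (i, j), axis = 1
    auto base = flat_index(strides, {i, j});
    assert(base + (k - j) * strides[1] == flat_index(strides, {i, k}));
}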
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_gemm : op_parser<parse_gemm>
{
std::vector<op_desc> operators() const { return {{"Gemm"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
float alpha = 1.0f;
float beta = 1.0f;
bool transa = false;
bool transb = false;
if(contains(info.attributes, "alpha"))
{
alpha = parser.parse_value(info.attributes.at("alpha")).at<float>();
}
if(contains(info.attributes, "beta"))
{
beta = parser.parse_value(info.attributes.at("beta")).at<float>();
}
if(contains(info.attributes, "transA"))
{
transa = parser.parse_value(info.attributes.at("transA")).at<bool>();
}
if(contains(info.attributes, "transB"))
{
transb = parser.parse_value(info.attributes.at("transB")).at<bool>();
}
std::vector<int64_t> perm(args[0]->get_shape().lens().size());
std::iota(perm.begin(), perm.end(), int64_t{0});
// swap the last two elements
std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[0])
: args[0];
auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
: args[1];
if(args.size() == 3)
{
if(beta != 0.f && args[2]->get_shape().elements() > 0)
{
auto out_lens = l1->get_shape().lens();
out_lens.back() = l2->get_shape().lens().back();
auto l3 = args[2];
auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{
l3 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", out_lens}}), args[2]);
}
return info.add_instruction(
make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
}
}
return info.add_instruction(make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
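// ONNX Gemm computes Y = alpha * op(A) * op(B) + beta * C, where op() is the
// optional transpose handled above by permuting the last two dims. A plain
// 2-D reference loop showing how alpha and beta enter, for illustration only:
#include <iostream>
#include <vector>

std::vector<float> gemm(const std::vector<float>& a, const std::vector<float>& b,
                        const std::vector<float>& c, int m, int k, int n,
                        float alpha, float beta)
{
    std::vector<float> y(m * n, 0.0f);
    for(int i = 0; i < m; i++)
        for(int j = 0; j < n; j++)
        {
            float acc = 0.0f;
            for(int p = 0; p < k; p++)
                acc += a[i * k + p] * b[p * n + j];
            y[i * n + j] = alpha * acc + beta * c[i * n + j];
        }
    return y;
}

int main()
{
    // (1x2)*(2x1) with alpha=2, beta=0.5: 2*(1*3 + 2*4) + 0.5*10 = 27
    std::cout << gemm({1, 2}, {3, 4}, {10}, 1, 2, 1, 2.0f, 0.5f)[0] << "\n"; // prints 27
}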
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_generic_op : op_parser<parse_generic_op>
{
std::vector<op_desc> operators() const
{
return {{"Abs", "abs"},
{"Acos", "acos"},
{"Acosh", "acosh"},
{"Asin", "asin"},
{"Asinh", "asinh"},
{"Atan", "atan"},
{"Atanh", "atanh"},
{"Ceil", "ceil"},
{"Concat", "concat"},
{"Cos", "cos"},
{"Cosh", "cosh"},
{"Elu", "elu"},
{"Erf", "erf"},
{"Exp", "exp"},
{"Flatten", "flatten"},
{"Floor", "floor"},
{"Gather", "gather"},
{"Identity", "identity"},
{"LeakyRelu", "leaky_relu"},
{"Log", "log"},
{"LogSoftmax", "logsoftmax"},
{"LRN", "lrn"},
{"Neg", "neg"},
{"Reciprocal", "recip"},
{"Relu", "relu"},
{"Round", "round"},
{"Sigmoid", "sigmoid"},
{"Sign", "sign"},
{"Sin", "sin"},
{"Sinh", "sinh"},
{"Softmax", "softmax"},
{"Sqrt", "sqrt"},
{"Squeeze", "squeeze"},
{"Tan", "tan"},
{"Tanh", "tanh"},
{"Unsqueeze", "unsqueeze"}};
}
bool needs_contiguous(const std::string& op_name) const
{
return contains({"gather", "squeeze", "unsqueeze"}, op_name);
}
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
auto op = parser.load(opd.op_name, info);
if(needs_contiguous(opd.op_name))
{
std::transform(args.begin(), args.end(), args.begin(), [&](auto arg) {
return info.make_contiguous(arg);
});
}
return info.add_instruction(op, args);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/map_activation_functions.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_gru : op_parser<parse_gru>
{
std::vector<op_desc> operators() const { return {{"GRU"}}; }
std::vector<instruction_ref> parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
if(contains(info.attributes, "hidden_size"))
{
std::size_t hidden_size_att =
parser.parse_value(info.attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
}
}
// parse the direction attribute, defaulting to forward
std::string direction{"forward"};
if(contains(info.attributes, "direction"))
{
direction = info.attributes.at("direction").s();
}
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn_direction::reverse;
}
std::vector<std::string> vec_names = {"sigmoid", "tanh"};
if(contains(info.attributes, "activations"))
{
auto names = info.attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
return to_lower(name);
});
}
// need 4 activation functions
if(dirct == op::rnn_direction::bidirectional)
{
// 4 activation functions are used in the bidirectional
// scenario, but the ONNX operator spec does not say how
// missing ones are filled in. The convention used here: if 1
// activation function is provided, repeat it four times; if 2
// are provided, assume forward and reverse use the same pair;
// if 3 are provided, repeat the 3rd once and use it for the
// reverse direction.
// This may need to change later
if(vec_names.size() == 1)
{
vec_names.insert(vec_names.end(), 3, vec_names.at(0));
}
else if(vec_names.size() == 2)
{
// repeat the activation functions
vec_names.push_back(vec_names.at(0));
vec_names.push_back(vec_names.at(1));
}
else if(vec_names.size() == 3)
{
vec_names.push_back(vec_names.at(2));
}
}
else
{
if(vec_names.size() == 1)
{
vec_names.push_back(vec_names.at(0));
}
}
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_activation_functions().count(name) == 0);
});
if(name_it != vec_names.end())
{
MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
}
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(),
vec_names.end(),
vec_actv_funcs.begin(),
[&](const auto& name) { return map_activation_functions().at(name); });
float clip = 0.0;
if(contains(info.attributes, "clip"))
{
clip = parser.parse_value(info.attributes.at("clip")).at<float>();
}
int linear_before_reset = 0;
if(contains(info.attributes, "linear_before_reset"))
{
linear_before_reset =
parser.parse_value(info.attributes.at("linear_before_reset")).at<int>();
}
// append undefined operators to make 6 arguments
if(args.size() < 6)
{
auto ins = info.add_instruction(make_op("undefined"));
args.insert(args.end(), 6 - args.size(), ins);
}
// first output for concatenation of hidden states
auto hidden_states =
info.add_instruction(make_op("gru",
{{"hidden_size", hidden_size},
{"actv_func", to_value(vec_actv_funcs)},
{"direction", dirct},
{"clip", clip},
{"linear_before_reset", linear_before_reset}}),
args);
// second output for last gru output
auto last_output = info.add_instruction(make_op("rnn_last_hs_output"), hidden_states);
return {hidden_states, last_output};
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_imagescalar : op_parser<parse_imagescalar>
{
std::vector<op_desc> operators() const { return {{"ImageScaler"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
float scale = 1.0;
std::vector<float> bias{};
if(contains(info.attributes, "scale"))
{
scale = parser.parse_value(info.attributes.at("scale")).at<float>();
}
if(contains(info.attributes, "bias"))
{
auto&& bias_floats = info.attributes["bias"].floats();
bias = std::vector<float>(bias_floats.begin(), bias_floats.end());
}
auto input_shape = args.front()->get_shape();
auto const& input_lens = input_shape.lens();
auto input_type = input_shape.type();
auto scale_val = info.add_literal(literal{shape{input_type}, {scale}});
auto bias_vals = info.add_literal(literal{shape{input_type, {bias.size()}}, bias});
auto scale_tensor = info.add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", input_lens}}), scale_val);
auto img_scaled =
info.add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor);
auto bias_bcast = info.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", input_lens}}), bias_vals);
return info.add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_instancenorm : op_parser<parse_instancenorm>
{
std::vector<op_desc> operators() const { return {{"InstanceNormalization"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
// y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
// mean = reduce_mean({D1, D2, ... Dk}, x)
// variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon"))
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
auto x = args[0];
auto scale = args[1];
auto bias = args[2];
auto dims = x->get_shape().lens();
auto ndims = dims.size();
assert(ndims >= 2);
auto kdims = ndims - 2;
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2);
auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_bcast =
info.add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), mean);
auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = info.add_literal(epsilon);
auto epsilon_bcast = info.add_instruction(
make_op("multibroadcast", {{"output_lens", dims}}), epsilon_literal);
auto variance_bcast =
info.add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), variance);
auto l2 = info.add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = info.add_instruction(make_op("rsqrt"), l2);
auto l4 = info.add_instruction(make_op("mul"), l1, l3);
auto scale_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), scale);
auto bias_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), bias);
auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast);
return info.add_instruction(make_op("add"), l5, bias_bcast);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
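// The instruction chain above builds, per (batch, channel) slice,
// y = scale * (x - mean) / sqrt(variance + epsilon) + bias, reducing over the
// spatial dims only. A minimal numeric sketch of one slice for illustration:
#include <cmath>
#include <iostream>
#include <vector>

std::vector<float> instance_norm_slice(std::vector<float> x, float scale, float bias,
                                       float epsilon = 1e-5f)
{
    float mean = 0.0f;
    for(float v : x)
        mean += v;
    mean /= x.size();
    float var = 0.0f;
    for(float v : x)
        var += (v - mean) * (v - mean);
    var /= x.size();
    for(float& v : x)
        v = scale * (v - mean) / std::sqrt(var + epsilon) + bias;
    return x;
}

int main()
{
    for(float v : instance_norm_slice({1, 2, 3, 4}, 1.0f, 0.0f))
        std::cout << v << " "; // roughly: -1.342 -0.447 0.447 1.342
}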
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/map_activation_functions.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv_func_names)
{
// need 6 activation functions for the bidirectional case
if(dirct == op::rnn_direction::bidirectional)
{
// 6 activation functions are used in the bidirectional
// scenario, but the ONNX operator spec does not say how
// missing ones are filled in. The convention used here: if 1
// activation function is provided, repeat it six times; if 2
// are provided, repeat the 2nd once and then repeat all three;
// if 3 are provided, repeat all three once; if 4 or 5 are
// provided, repeat the last one until there are six.
// This may need to change later
switch(actv_func_names.size())
{
case 1:
actv_func_names = {actv_func_names.at(0),
actv_func_names.at(0),
actv_func_names.at(0),
actv_func_names.at(0),
actv_func_names.at(0),
actv_func_names.at(0)};
break;
case 2:
// repeat the 2nd actv func once, then repeat all three another time
actv_func_names = {actv_func_names.at(0),
actv_func_names.at(1),
actv_func_names.at(1),
actv_func_names.at(0),
actv_func_names.at(1),
actv_func_names.at(1)};
break;
case 3:
// repeat all three actv funcs once
actv_func_names = {actv_func_names.at(0),
actv_func_names.at(1),
actv_func_names.at(2),
actv_func_names.at(0),
actv_func_names.at(1),
actv_func_names.at(2)};
break;
case 4:
actv_func_names = {actv_func_names.at(0),
actv_func_names.at(1),
actv_func_names.at(2),
actv_func_names.at(3),
actv_func_names.at(3),
actv_func_names.at(3)};
break;
case 5:
actv_func_names = {actv_func_names.at(0),
actv_func_names.at(1),
actv_func_names.at(2),
actv_func_names.at(3),
actv_func_names.at(4),
actv_func_names.at(4)};
break;
default: break;
}
}
else
{
switch(actv_func_names.size())
{
case 1:
actv_func_names = {actv_func_names.at(0), actv_func_names.at(0), actv_func_names.at(0)};
break;
case 2:
// repeat the 2nd actv func once, so we have 3 actv funcs
actv_func_names = {actv_func_names.at(0), actv_func_names.at(1), actv_func_names.at(1)};
break;
default: break;
}
}
}
struct parse_lstm : op_parser<parse_lstm>
{
std::vector<op_desc> operators() const { return {{"LSTM"}}; }
std::vector<instruction_ref> parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
if(contains(info.attributes, "hidden_size"))
{
std::size_t hidden_size_att =
parser.parse_value(info.attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
}
}
// parse the direction attribute, defaulting to forward
std::string direction{"forward"};
if(contains(info.attributes, "direction"))
{
direction = info.attributes.at("direction").s();
}
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn_direction::reverse;
}
else if(direction == "forward")
{
dirct = op::rnn_direction::forward;
}
else
{
MIGRAPHX_THROW("LSTM: incorrect direction attribute");
}
std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
if(contains(info.attributes, "activations"))
{
auto names = info.attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
return to_lower(name);
});
}
lstm_actv_functions(dirct, vec_names);
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_activation_functions().count(name) == 0);
});
if(name_it != vec_names.end())
{
MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
}
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(),
vec_names.end(),
vec_actv_funcs.begin(),
[&](const auto& name) { return map_activation_functions().at(name); });
float clip = 0.0;
if(contains(info.attributes, "clip"))
{
clip = parser.parse_value(info.attributes.at("clip")).at<float>();
}
int input_forget = 0;
if(contains(info.attributes, "input_forget"))
{
input_forget = parser.parse_value(info.attributes.at("input_forget")).at<int>();
}
// append undefined operators to make 8 arguments
if(args.size() < 8)
{
auto ins = info.add_instruction(make_op("undefined"));
args.insert(args.end(), 8 - args.size(), ins);
}
// first output for concatenation of hidden states
auto hidden_states = info.add_instruction(make_op("lstm",
{{"hidden_size", hidden_size},
{"actv_func", to_value(vec_actv_funcs)},
{"direction", dirct},
{"clip", clip},
{"input_forget", input_forget}}),
args);
auto last_output = info.add_instruction(make_op("rnn_last_hs_output"), hidden_states);
// third output for last cell output
auto last_cell_output =
info.add_instruction(make_op("rnn_last_cell_output"), hidden_states);
return {hidden_states, last_output, last_cell_output};
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_matmul : op_parser<parse_matmul>
{
std::vector<op_desc> operators() const
{
return {{"MatMul", "dot"}, {"MatMulInteger", "quant_dot"}};
}
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto l0 = args[0];
auto l1 = args[1];
auto l0_lens = l0->get_shape().lens();
auto l1_lens = l1->get_shape().lens();
// args[0] is a vector, prepend 1 to the shape
bool is_a_prepended = false;
if(l0_lens.size() == 1)
{
is_a_prepended = true;
l0_lens.insert(l0_lens.begin(), 1);
l0 = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), args[0]);
}
bool is_b_appended = false;
if(l1_lens.size() == 1)
{
is_b_appended = true;
l1_lens.push_back(1);
l1 = info.add_instruction(make_op("unsqueeze", {{"axes", {1}}}), args[1]);
}
instruction_ref bl0 = l0;
instruction_ref bl1 = l1;
if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
{
auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
l0_broadcasted_lens = output_lens;
l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
l1_broadcasted_lens = output_lens;
l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
if(l0_lens != l0_broadcasted_lens)
{
bl0 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", l0_broadcasted_lens}}), l0);
}
if(l1_lens != l1_broadcasted_lens)
{
bl1 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", l1_broadcasted_lens}}), l1);
}
}
auto dot_res =
info.add_instruction(make_op(opd.op_name, {{"alpha", 1}, {"beta", 0}}), bl0, bl1);
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
if(is_a_prepended)
{
dot_res = info.add_instruction(make_op("squeeze", {{"axes", {num_axis - 2}}}), dot_res);
--num_axis;
}
if(is_b_appended)
{
dot_res = info.add_instruction(make_op("squeeze", {{"axes", {num_axis - 1}}}), dot_res);
}
return dot_res;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
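// parse_matmul above follows the numpy matmul conventions: a 1-D first input
// is promoted to 1 x K and the resulting axis squeezed away afterwards, a 1-D
// second input is promoted to K x 1 likewise, and batch dims are broadcast. A
// sketch of just the promote/squeeze shape bookkeeping, assuming no batch
// dims:
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

std::vector<std::size_t> matmul_lens(std::vector<std::size_t> a, std::vector<std::size_t> b)
{
    bool a_prepended = a.size() == 1;
    bool b_appended  = b.size() == 1;
    if(a_prepended)
        a.insert(a.begin(), 1); // vector -> 1 x K
    if(b_appended)
        b.push_back(1); // vector -> K x 1
    if(a[1] != b[0])
        throw std::runtime_error("matmul: inner dims must agree");
    std::vector<std::size_t> out{a[0], b[1]}; // M x N
    if(b_appended)
        out.pop_back(); // squeeze the appended column axis
    if(a_prepended)
        out.erase(out.begin()); // squeeze the prepended row axis
    return out;
}

int main()
{
    std::cout << matmul_lens({4}, {4, 5})[0] << "\n"; // vec . mat -> {5}
    std::cout << matmul_lens({3, 4}, {4})[0] << "\n"; // mat . vec -> {3}
}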