#include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace onnx { struct parse_binary_op : op_parser { std::vector operators() const { return {{"Add", "add"}, {"Div", "div"}, {"Mul", "mul"}, {"Pow", "pow"}, {"PRelu", "prelu"}, {"Sub", "sub"}}; } instruction_ref parse(const op_desc& opd, const onnx_parser& parser, onnx_parser::node_info info, std::vector args) const { if(args.size() != 2) MIGRAPHX_THROW("binary operators should have 2 operands"); if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis")) { uint64_t broadcasted = parser.parse_value(info.attributes.at("broadcast")).at(); if(broadcasted != 0) { uint64_t axis = parser.parse_value(info.attributes.at("axis")).at(); auto l = info.add_instruction( make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}), args[1]); return info.add_instruction(make_op(opd.op_name), args[0], l); } return info.add_instruction(make_op(opd.op_name), args); } else { return info.add_broadcastable_binary_op(opd.op_name, args[0], args[1]); } } }; } // namespace onnx } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx