#include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace onnx { struct parse_layernorm : op_parser { std::vector operators() const { return {{"LayerNormalization"}}; } instruction_ref parse(const op_desc& /*opd*/, const onnx_parser& parser, onnx_parser::node_info info, const std::vector& args) const { float epsilon = 1e-3f; int64_t axis = -1; if(contains(info.attributes, "epsilon")) { epsilon = parser.parse_value(info.attributes.at("epsilon")).at(); } if(contains(info.attributes, "axis")) { epsilon = parser.parse_value(info.attributes.at("axis")).at(); } auto layernorm = info.add_instruction(make_op("layernorm", {{"epsilon", epsilon}, {"axis", axis}}), args.front()); if (args.size() == 3) { layernorm = info.add_broadcastable_binary_op("mul", layernorm, args.at(1)); layernorm = info.add_broadcastable_binary_op("add", layernorm, args.at(2)); } return layernorm; } }; } // namespace onnx } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx