#include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace onnx { struct parse_rnn : op_parser { std::vector operators() const { return {{"RNN"}}; } std::vector parse(const op_desc& /*opd*/, const onnx_parser& parser, onnx_parser::node_info info, std::vector args) const { migraphx::shape input_shape = args[0]->get_shape(); std::size_t hidden_size = args[1]->get_shape().lens()[1]; if(contains(info.attributes, "hidden_size")) { std::size_t hidden_size_att = parser.parse_value(info.attributes.at("hidden_size")).at(); if(hidden_size != hidden_size_att) { MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute"); } } // Handling of direction to be added later std::string direction{"forward"}; if(contains(info.attributes, "direction")) { direction = info.attributes.at("direction").s(); } op::rnn_direction dirct = op::rnn_direction::forward; if(direction == "bidirectional") { dirct = op::rnn_direction::bidirectional; } else if(direction == "reverse") { dirct = op::rnn_direction::reverse; } std::vector vec_names{"tanh"}; if(contains(info.attributes, "activations")) { auto names = info.attributes.at("activations").strings(); vec_names.clear(); vec_names.resize(names.size()); std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) { return to_lower(name); }); } auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) { return (map_activation_functions().count(name) == 0); }); if(name_it != vec_names.end()) { MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported"); } // bidirectional case should have two activation functions. // one is for forward, and the other is for reverse. // if only one actv function is provided, we use it in both // forward and reverse direction if(dirct == op::rnn_direction::bidirectional) { if(vec_names.size() == 1) { vec_names.push_back(vec_names.at(0)); } } std::vector vec_actv_funcs(vec_names.size()); std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](const auto& fn) { return map_activation_functions().at(fn); }); // To be added later float clip = 0.0; if(contains(info.attributes, "clip")) { clip = parser.parse_value(info.attributes.at("clip")).at(); } // if the number of arguments is less than 6, append // undefined operator to have 6 arguments if(args.size() < 6) { auto ins = info.add_instruction(make_op("undefined")); args.insert(args.end(), (6 - args.size()), ins); } // first output for the concatenation of hidden states auto hidden_states = info.add_instruction(make_op("rnn", {{"hidden_size", hidden_size}, {"actv_func", to_value(vec_actv_funcs)}, {"direction", dirct}, {"clip", clip}}), args); // second output for the last hidden state auto last_output = info.add_instruction(make_op("rnn_last_hs_output"), hidden_states); return {hidden_states, last_output}; } }; } // namespace onnx } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx