#include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace onnx { void lstm_actv_functions(op::rnn_direction dirct, std::vector& actv_func_names) { // need 6 activation functions for bidirectional directions if(dirct == op::rnn_direction::bidirectional) { // 6 activation functions are used in the bidirectional // scenario. No spec is provided in onnx::operator. we // use the algorithm that: if 1 actv function is provided, // repeat 1st six times. If 2 actv functins are provided, // repeat 2nd once, then repeat all three once // if 3 actv funcs are provide, repeat all three once. // the same algorithm is used for 4, 5, and 6 actv funcions // provided. This may need change later switch(actv_func_names.size()) { case 1: actv_func_names = {actv_func_names.at(0), actv_func_names.at(0), actv_func_names.at(0), actv_func_names.at(0), actv_func_names.at(0), actv_func_names.at(0)}; break; case 2: // repeat the 2nd actv func once, then repeat all three another time actv_func_names = {actv_func_names.at(0), actv_func_names.at(1), actv_func_names.at(1), actv_func_names.at(0), actv_func_names.at(1), actv_func_names.at(1)}; break; case 3: // repeat all three actv funcs once actv_func_names = {actv_func_names.at(0), actv_func_names.at(1), actv_func_names.at(2), actv_func_names.at(0), actv_func_names.at(1), actv_func_names.at(2)}; break; case 4: actv_func_names = {actv_func_names.at(0), actv_func_names.at(1), actv_func_names.at(2), actv_func_names.at(3), actv_func_names.at(3), actv_func_names.at(3)}; break; case 5: actv_func_names = {actv_func_names.at(0), actv_func_names.at(1), actv_func_names.at(2), actv_func_names.at(3), actv_func_names.at(4), actv_func_names.at(4)}; break; default: break; } } else { switch(actv_func_names.size()) { case 1: actv_func_names = {actv_func_names.at(0), actv_func_names.at(0), actv_func_names.at(0)}; break; case 2: // repeat the 2nd actv func once, so we have 3 actv funcs actv_func_names = {actv_func_names.at(0), actv_func_names.at(1), actv_func_names.at(1)}; break; default: break; } } } struct parse_lstm : op_parser { std::vector operators() const { return {{"LSTM"}}; } std::vector parse(const op_desc& /*opd*/, const onnx_parser& parser, onnx_parser::node_info info, std::vector args) const { migraphx::shape input_shape = args[0]->get_shape(); std::size_t hidden_size = args[2]->get_shape().lens()[2]; if(contains(info.attributes, "hidden_size")) { std::size_t hidden_size_att = parser.parse_value(info.attributes.at("hidden_size")).at(); if(hidden_size != hidden_size_att) { MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute"); } } // Handling of direction to be added later std::string direction{"forward"}; if(contains(info.attributes, "direction")) { direction = info.attributes.at("direction").s(); } op::rnn_direction dirct = op::rnn_direction::forward; if(direction == "bidirectional") { dirct = op::rnn_direction::bidirectional; } else if(direction == "reverse") { dirct = op::rnn_direction::reverse; } else if(direction == "forward") { dirct = op::rnn_direction::forward; } else { MIGRAPHX_THROW("LSTM: incorrect direction attribute"); } std::vector vec_names = {"sigmoid", "tanh", "tanh"}; if(contains(info.attributes, "activations")) { auto names = info.attributes.at("activations").strings(); vec_names.clear(); vec_names.resize(names.size()); std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) { return to_lower(name); }); } lstm_actv_functions(dirct, vec_names); auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) { return (map_activation_functions().count(name) == 0); }); if(name_it != vec_names.end()) { MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported"); } std::vector vec_actv_funcs(vec_names.size()); std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](const auto& name) { return map_activation_functions().at(name); }); float clip = 0.0; if(contains(info.attributes, "clip")) { clip = parser.parse_value(info.attributes.at("clip")).at(); } int input_forget = 0; if(contains(info.attributes, "input_forget")) { input_forget = parser.parse_value(info.attributes.at("input_forget")).at(); } // append undefined opeator to make 6 arguments if(args.size() < 8) { auto ins = info.add_instruction(make_op("undefined")); args.insert(args.end(), 8 - args.size(), ins); } // first output for concatenation of hidden states auto hidden_states = info.add_instruction(make_op("lstm", {{"hidden_size", hidden_size}, {"actv_func", to_value(vec_actv_funcs)}, {"direction", dirct}, {"clip", clip}, {"input_forget", input_forget}}), args); auto last_output = info.add_instruction(make_op("rnn_last_hs_output"), hidden_states); // third output for last cell output auto last_cell_output = info.add_instruction(make_op("rnn_last_cell_output"), hidden_states); return {hidden_states, last_output, last_cell_output}; } }; } // namespace onnx } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx