Commit 3af87aaf authored by Shucai Xiao's avatar Shucai Xiao Committed by mvermeulen
Browse files

Handle optional outputs (#418)

* change to support optional outputs

* clang format

* add onnx test for better code coverage

* add corresponding onnx file

* fix review comments of handling optional program outputs

* clang format

* change onnx unit test to pass

* clang format

* refine onnx unit tests

* clang format

* remove unnecessary code
parent 992666e6
...@@ -1454,6 +1454,19 @@ struct onnx_parser ...@@ -1454,6 +1454,19 @@ struct onnx_parser
{ {
this->parse_node(output.name()); this->parse_node(output.name());
} }
// For now, the last output with a valid name is considered
// as the program output, and add an identity instruction at
// the program end
auto prog_output = graph.output();
auto oit = std::find_if(prog_output.rbegin(), prog_output.rend(), [](auto& node) {
return !node.name().empty();
});
if(instructions.count(oit->name()) > 0)
{
prog.add_instruction(op::identity{}, instructions[oit->name()]);
}
} }
void parse_undefined(const std::string& name) void parse_undefined(const std::string& name)
...@@ -1472,14 +1485,14 @@ struct onnx_parser ...@@ -1472,14 +1485,14 @@ struct onnx_parser
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
for(auto&& input : node.input()) for(auto&& input : node.input())
{ {
if(nodes.count(input) > 0) if(input.empty())
{ {
assert(name != input); this->parse_undefined(input);
this->parse_node(input);
} }
else if(input.empty()) else if(nodes.count(input) > 0)
{ {
this->parse_undefined(input); assert(name != input);
this->parse_node(input);
} }
args.push_back(instructions.at(input)); args.push_back(instructions.at(input));
} }
...@@ -1499,12 +1512,12 @@ struct onnx_parser ...@@ -1499,12 +1512,12 @@ struct onnx_parser
} }
else else
{ {
assert(node.output().size() >= result.size()); assert(node.output().size() <= result.size());
std::transform(result.begin(), std::transform(node.output().begin(),
result.end(), node.output().end(),
node.output().begin(), result.begin(),
std::inserter(instructions, instructions.end()), std::inserter(instructions, instructions.end()),
[](auto&& x, auto&& y) { return std::make_pair(y, x); }); [](auto&& x, auto&& y) { return std::make_pair(x, y); });
} }
} }
} }
......
...@@ -4,9 +4,25 @@ ...@@ -4,9 +4,25 @@
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include "test.hpp" #include "test.hpp"
migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode = true)
{
auto prog = migraphx::parse_onnx(name);
if(eliminate_deadcode)
migraphx::run_passes(prog, {migraphx::dead_code_elimination{}});
// remove the last identity instruction
auto last_ins = std::prev(prog.end());
prog.remove_instruction(last_ins);
return prog;
}
TEST_CASE(rnn_test_bidirectional) TEST_CASE(rnn_test_bidirectional)
{ {
std::size_t sl = 5; // sequence len std::size_t sl = 5; // sequence len
...@@ -43,7 +59,7 @@ TEST_CASE(rnn_test_bidirectional) ...@@ -43,7 +59,7 @@ TEST_CASE(rnn_test_bidirectional)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_bi.onnx"); auto prog = optimize_onnx("onnx_rnn_bi.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -85,7 +101,7 @@ TEST_CASE(rnn_test_one_direction) ...@@ -85,7 +101,7 @@ TEST_CASE(rnn_test_one_direction)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_forward.onnx"); auto prog = optimize_onnx("onnx_rnn_forward.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -111,7 +127,7 @@ TEST_CASE(rnn_test_one_direction) ...@@ -111,7 +127,7 @@ TEST_CASE(rnn_test_one_direction)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_reverse.onnx"); auto prog = optimize_onnx("onnx_rnn_reverse.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -135,7 +151,7 @@ TEST_CASE(rnn_test_one_direction) ...@@ -135,7 +151,7 @@ TEST_CASE(rnn_test_one_direction)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_3args.onnx"); auto prog = optimize_onnx("onnx_rnn_3args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -163,7 +179,7 @@ TEST_CASE(rnn_test_one_direction) ...@@ -163,7 +179,7 @@ TEST_CASE(rnn_test_one_direction)
seq_len, seq_len,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_5args.onnx"); auto prog = optimize_onnx("onnx_rnn_5args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -207,7 +223,7 @@ TEST_CASE(gru_test) ...@@ -207,7 +223,7 @@ TEST_CASE(gru_test)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward.onnx"); auto prog = optimize_onnx("onnx_gru_forward.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -241,7 +257,7 @@ TEST_CASE(gru_test) ...@@ -241,7 +257,7 @@ TEST_CASE(gru_test)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse.onnx"); auto prog = optimize_onnx("onnx_gru_reverse.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -278,7 +294,7 @@ TEST_CASE(gru_test) ...@@ -278,7 +294,7 @@ TEST_CASE(gru_test)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx"); auto prog = optimize_onnx("onnx_gru_bi.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -317,7 +333,7 @@ TEST_CASE(gru_test_args) ...@@ -317,7 +333,7 @@ TEST_CASE(gru_test_args)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_3arg.onnx"); auto prog = optimize_onnx("onnx_gru_3arg.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -349,7 +365,7 @@ TEST_CASE(gru_test_args) ...@@ -349,7 +365,7 @@ TEST_CASE(gru_test_args)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_4arg.onnx"); auto prog = optimize_onnx("onnx_gru_4arg.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -386,7 +402,7 @@ TEST_CASE(gru_test_args) ...@@ -386,7 +402,7 @@ TEST_CASE(gru_test_args)
seq_len, seq_len,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_5arg.onnx"); auto prog = optimize_onnx("onnx_gru_5arg.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -432,7 +448,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -432,7 +448,7 @@ TEST_CASE(gru_test_actv_funcs)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx"); auto prog = optimize_onnx("onnx_gru_bi_0.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -469,7 +485,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -469,7 +485,7 @@ TEST_CASE(gru_test_actv_funcs)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx"); auto prog = optimize_onnx("onnx_gru_bi_1.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -506,7 +522,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -506,7 +522,7 @@ TEST_CASE(gru_test_actv_funcs)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_2.onnx"); auto prog = optimize_onnx("onnx_gru_bi_2.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -543,7 +559,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -543,7 +559,7 @@ TEST_CASE(gru_test_actv_funcs)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx"); auto prog = optimize_onnx("onnx_gru_bi_3.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -577,7 +593,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -577,7 +593,7 @@ TEST_CASE(gru_test_actv_funcs)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx"); auto prog = optimize_onnx("onnx_gru_forward_0.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -611,7 +627,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -611,7 +627,7 @@ TEST_CASE(gru_test_actv_funcs)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx"); auto prog = optimize_onnx("onnx_gru_reverse_1.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -660,8 +676,7 @@ TEST_CASE(lstm_forward) ...@@ -660,8 +676,7 @@ TEST_CASE(lstm_forward)
ic, ic,
pph); pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_forward.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_forward.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -690,8 +705,93 @@ TEST_CASE(lstm_forward) ...@@ -690,8 +705,93 @@ TEST_CASE(lstm_forward)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f3args.onnx");
EXPECT(p == prog);
}
// 3 args, hs output
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
auto prog = optimize_onnx("onnx_lstm_hs.onnx");
EXPECT(p == prog);
}
// 3 args, last output
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_last.onnx");
EXPECT(p == prog);
}
// 3 args, cell output
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f3args.onnx"); auto prog = optimize_onnx("onnx_lstm_cell.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -721,8 +821,7 @@ TEST_CASE(lstm_forward) ...@@ -721,8 +821,7 @@ TEST_CASE(lstm_forward)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_f4args.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_f4args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -752,9 +851,8 @@ TEST_CASE(lstm_forward) ...@@ -752,9 +851,8 @@ TEST_CASE(lstm_forward)
und, und,
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f5args.onnx"); auto prog = optimize_onnx("onnx_lstm_f5args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -785,9 +883,8 @@ TEST_CASE(lstm_forward) ...@@ -785,9 +883,8 @@ TEST_CASE(lstm_forward)
ih, ih,
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f6args.onnx"); auto prog = optimize_onnx("onnx_lstm_f6args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -819,9 +916,8 @@ TEST_CASE(lstm_forward) ...@@ -819,9 +916,8 @@ TEST_CASE(lstm_forward)
ih, ih,
ic, ic,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f7args.onnx"); auto prog = optimize_onnx("onnx_lstm_f7args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -866,8 +962,7 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -866,8 +962,7 @@ TEST_CASE(lstm_forward_actv_func)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_f0af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_f0af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -897,8 +992,7 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -897,8 +992,7 @@ TEST_CASE(lstm_forward_actv_func)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_f1af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_f1af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -928,9 +1022,8 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -928,9 +1022,8 @@ TEST_CASE(lstm_forward_actv_func)
und, und,
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f2af.onnx"); auto prog = optimize_onnx("onnx_lstm_f2af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -979,8 +1072,7 @@ TEST_CASE(lstm_reverse) ...@@ -979,8 +1072,7 @@ TEST_CASE(lstm_reverse)
ic, ic,
pph); pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_reverse.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_reverse.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1010,9 +1102,8 @@ TEST_CASE(lstm_reverse) ...@@ -1010,9 +1102,8 @@ TEST_CASE(lstm_reverse)
und, und,
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_r5args.onnx"); auto prog = optimize_onnx("onnx_lstm_r5args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1041,8 +1132,7 @@ TEST_CASE(lstm_reverse) ...@@ -1041,8 +1132,7 @@ TEST_CASE(lstm_reverse)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_r0af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_r0af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1095,8 +1185,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -1095,8 +1185,7 @@ TEST_CASE(lstm_bidirectional)
ic, ic,
pph); pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1129,8 +1218,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -1129,8 +1218,7 @@ TEST_CASE(lstm_bidirectional)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi3args.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi3args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1164,8 +1252,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -1164,8 +1252,7 @@ TEST_CASE(lstm_bidirectional)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi4args.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi4args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1200,8 +1287,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -1200,8 +1287,7 @@ TEST_CASE(lstm_bidirectional)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi5args.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi5args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1237,8 +1323,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -1237,8 +1323,7 @@ TEST_CASE(lstm_bidirectional)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi6args.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi6args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1275,8 +1360,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -1275,8 +1360,7 @@ TEST_CASE(lstm_bidirectional)
ic, ic,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi7args.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi7args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1326,8 +1410,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1326,8 +1410,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi0af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi0af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1361,8 +1444,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1361,8 +1444,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi1af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi1af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1397,8 +1479,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1397,8 +1479,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi2af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi2af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1434,8 +1515,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1434,8 +1515,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi4af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi4af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1472,8 +1552,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1472,8 +1552,7 @@ TEST_CASE(lstm_bi_actv_funcs)
ic, ic,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi5af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi5af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1506,8 +1585,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1506,8 +1585,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi6af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi6af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment