Commit 3af87aaf authored by Shucai Xiao's avatar Shucai Xiao Committed by mvermeulen
Browse files

Handle optional outputs (#418)

* change to support optional outputs

* clang format

* add onnx test for better code coverage

* add corresponding onnx file

* fix review comments of handling optional program outputs

* clang format

* change onnx unit test to pass

* clang format

* refine onnx unit tests

* clang format

* remove unnecessary code
parent 992666e6
...@@ -1454,6 +1454,19 @@ struct onnx_parser ...@@ -1454,6 +1454,19 @@ struct onnx_parser
{ {
this->parse_node(output.name()); this->parse_node(output.name());
} }
// For now, the last output with a valid name is considered
// as the program output, and add an identity instruction at
// the program end
auto prog_output = graph.output();
auto oit = std::find_if(prog_output.rbegin(), prog_output.rend(), [](auto& node) {
return !node.name().empty();
});
if(instructions.count(oit->name()) > 0)
{
prog.add_instruction(op::identity{}, instructions[oit->name()]);
}
} }
void parse_undefined(const std::string& name) void parse_undefined(const std::string& name)
...@@ -1472,14 +1485,14 @@ struct onnx_parser ...@@ -1472,14 +1485,14 @@ struct onnx_parser
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
for(auto&& input : node.input()) for(auto&& input : node.input())
{ {
if(nodes.count(input) > 0) if(input.empty())
{ {
assert(name != input); this->parse_undefined(input);
this->parse_node(input);
} }
else if(input.empty()) else if(nodes.count(input) > 0)
{ {
this->parse_undefined(input); assert(name != input);
this->parse_node(input);
} }
args.push_back(instructions.at(input)); args.push_back(instructions.at(input));
} }
...@@ -1499,12 +1512,12 @@ struct onnx_parser ...@@ -1499,12 +1512,12 @@ struct onnx_parser
} }
else else
{ {
assert(node.output().size() >= result.size()); assert(node.output().size() <= result.size());
std::transform(result.begin(), std::transform(node.output().begin(),
result.end(), node.output().end(),
node.output().begin(), result.begin(),
std::inserter(instructions, instructions.end()), std::inserter(instructions, instructions.end()),
[](auto&& x, auto&& y) { return std::make_pair(y, x); }); [](auto&& x, auto&& y) { return std::make_pair(x, y); });
} }
} }
} }
......
...@@ -4,9 +4,25 @@ ...@@ -4,9 +4,25 @@
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include "test.hpp" #include "test.hpp"
migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode = true)
{
auto prog = migraphx::parse_onnx(name);
if(eliminate_deadcode)
migraphx::run_passes(prog, {migraphx::dead_code_elimination{}});
// remove the last identity instruction
auto last_ins = std::prev(prog.end());
prog.remove_instruction(last_ins);
return prog;
}
TEST_CASE(rnn_test_bidirectional) TEST_CASE(rnn_test_bidirectional)
{ {
std::size_t sl = 5; // sequence len std::size_t sl = 5; // sequence len
...@@ -43,7 +59,7 @@ TEST_CASE(rnn_test_bidirectional) ...@@ -43,7 +59,7 @@ TEST_CASE(rnn_test_bidirectional)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_bi.onnx"); auto prog = optimize_onnx("onnx_rnn_bi.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -85,7 +101,7 @@ TEST_CASE(rnn_test_one_direction) ...@@ -85,7 +101,7 @@ TEST_CASE(rnn_test_one_direction)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_forward.onnx"); auto prog = optimize_onnx("onnx_rnn_forward.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -111,7 +127,7 @@ TEST_CASE(rnn_test_one_direction) ...@@ -111,7 +127,7 @@ TEST_CASE(rnn_test_one_direction)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_reverse.onnx"); auto prog = optimize_onnx("onnx_rnn_reverse.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -135,7 +151,7 @@ TEST_CASE(rnn_test_one_direction) ...@@ -135,7 +151,7 @@ TEST_CASE(rnn_test_one_direction)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_3args.onnx"); auto prog = optimize_onnx("onnx_rnn_3args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -163,7 +179,7 @@ TEST_CASE(rnn_test_one_direction) ...@@ -163,7 +179,7 @@ TEST_CASE(rnn_test_one_direction)
seq_len, seq_len,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_5args.onnx"); auto prog = optimize_onnx("onnx_rnn_5args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -207,7 +223,7 @@ TEST_CASE(gru_test) ...@@ -207,7 +223,7 @@ TEST_CASE(gru_test)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward.onnx"); auto prog = optimize_onnx("onnx_gru_forward.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -241,7 +257,7 @@ TEST_CASE(gru_test) ...@@ -241,7 +257,7 @@ TEST_CASE(gru_test)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse.onnx"); auto prog = optimize_onnx("onnx_gru_reverse.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -278,7 +294,7 @@ TEST_CASE(gru_test) ...@@ -278,7 +294,7 @@ TEST_CASE(gru_test)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx"); auto prog = optimize_onnx("onnx_gru_bi.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -317,7 +333,7 @@ TEST_CASE(gru_test_args) ...@@ -317,7 +333,7 @@ TEST_CASE(gru_test_args)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_3arg.onnx"); auto prog = optimize_onnx("onnx_gru_3arg.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -349,7 +365,7 @@ TEST_CASE(gru_test_args) ...@@ -349,7 +365,7 @@ TEST_CASE(gru_test_args)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_4arg.onnx"); auto prog = optimize_onnx("onnx_gru_4arg.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -386,7 +402,7 @@ TEST_CASE(gru_test_args) ...@@ -386,7 +402,7 @@ TEST_CASE(gru_test_args)
seq_len, seq_len,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_5arg.onnx"); auto prog = optimize_onnx("onnx_gru_5arg.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -432,7 +448,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -432,7 +448,7 @@ TEST_CASE(gru_test_actv_funcs)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx"); auto prog = optimize_onnx("onnx_gru_bi_0.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -469,7 +485,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -469,7 +485,7 @@ TEST_CASE(gru_test_actv_funcs)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx"); auto prog = optimize_onnx("onnx_gru_bi_1.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -506,7 +522,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -506,7 +522,7 @@ TEST_CASE(gru_test_actv_funcs)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_2.onnx"); auto prog = optimize_onnx("onnx_gru_bi_2.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -543,7 +559,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -543,7 +559,7 @@ TEST_CASE(gru_test_actv_funcs)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx"); auto prog = optimize_onnx("onnx_gru_bi_3.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -577,7 +593,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -577,7 +593,7 @@ TEST_CASE(gru_test_actv_funcs)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx"); auto prog = optimize_onnx("onnx_gru_forward_0.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -611,7 +627,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -611,7 +627,7 @@ TEST_CASE(gru_test_actv_funcs)
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx"); auto prog = optimize_onnx("onnx_gru_reverse_1.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -660,8 +676,7 @@ TEST_CASE(lstm_forward) ...@@ -660,8 +676,7 @@ TEST_CASE(lstm_forward)
ic, ic,
pph); pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_forward.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_forward.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -690,8 +705,93 @@ TEST_CASE(lstm_forward) ...@@ -690,8 +705,93 @@ TEST_CASE(lstm_forward)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f3args.onnx");
EXPECT(p == prog);
}
// 3 args, hs output
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
auto prog = optimize_onnx("onnx_lstm_hs.onnx");
EXPECT(p == prog);
}
// 3 args, last output
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_last.onnx");
EXPECT(p == prog);
}
// 3 args, cell output
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f3args.onnx"); auto prog = optimize_onnx("onnx_lstm_cell.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -721,8 +821,7 @@ TEST_CASE(lstm_forward) ...@@ -721,8 +821,7 @@ TEST_CASE(lstm_forward)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_f4args.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_f4args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -752,9 +851,8 @@ TEST_CASE(lstm_forward) ...@@ -752,9 +851,8 @@ TEST_CASE(lstm_forward)
und, und,
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f5args.onnx"); auto prog = optimize_onnx("onnx_lstm_f5args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -785,9 +883,8 @@ TEST_CASE(lstm_forward) ...@@ -785,9 +883,8 @@ TEST_CASE(lstm_forward)
ih, ih,
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f6args.onnx"); auto prog = optimize_onnx("onnx_lstm_f6args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -819,9 +916,8 @@ TEST_CASE(lstm_forward) ...@@ -819,9 +916,8 @@ TEST_CASE(lstm_forward)
ih, ih,
ic, ic,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f7args.onnx"); auto prog = optimize_onnx("onnx_lstm_f7args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -866,8 +962,7 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -866,8 +962,7 @@ TEST_CASE(lstm_forward_actv_func)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_f0af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_f0af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -897,8 +992,7 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -897,8 +992,7 @@ TEST_CASE(lstm_forward_actv_func)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_f1af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_f1af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -928,9 +1022,8 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -928,9 +1022,8 @@ TEST_CASE(lstm_forward_actv_func)
und, und,
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f2af.onnx"); auto prog = optimize_onnx("onnx_lstm_f2af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -979,8 +1072,7 @@ TEST_CASE(lstm_reverse) ...@@ -979,8 +1072,7 @@ TEST_CASE(lstm_reverse)
ic, ic,
pph); pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_reverse.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_reverse.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1010,9 +1102,8 @@ TEST_CASE(lstm_reverse) ...@@ -1010,9 +1102,8 @@ TEST_CASE(lstm_reverse)
und, und,
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_r5args.onnx"); auto prog = optimize_onnx("onnx_lstm_r5args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1041,8 +1132,7 @@ TEST_CASE(lstm_reverse) ...@@ -1041,8 +1132,7 @@ TEST_CASE(lstm_reverse)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_r0af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_r0af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1095,8 +1185,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -1095,8 +1185,7 @@ TEST_CASE(lstm_bidirectional)
ic, ic,
pph); pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1129,8 +1218,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -1129,8 +1218,7 @@ TEST_CASE(lstm_bidirectional)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi3args.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi3args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1164,8 +1252,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -1164,8 +1252,7 @@ TEST_CASE(lstm_bidirectional)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi4args.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi4args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1200,8 +1287,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -1200,8 +1287,7 @@ TEST_CASE(lstm_bidirectional)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi5args.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi5args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1237,8 +1323,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -1237,8 +1323,7 @@ TEST_CASE(lstm_bidirectional)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi6args.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi6args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1275,8 +1360,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -1275,8 +1360,7 @@ TEST_CASE(lstm_bidirectional)
ic, ic,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi7args.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi7args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1326,8 +1410,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1326,8 +1410,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi0af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi0af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1361,8 +1444,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1361,8 +1444,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi1af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi1af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1397,8 +1479,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1397,8 +1479,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi2af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi2af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1434,8 +1515,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1434,8 +1515,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi4af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi4af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1472,8 +1552,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1472,8 +1552,7 @@ TEST_CASE(lstm_bi_actv_funcs)
ic, ic,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi5af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi5af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1506,8 +1585,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1506,8 +1585,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und, und,
und); und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = optimize_onnx("onnx_lstm_bi6af.onnx");
auto prog = migraphx::parse_onnx("onnx_lstm_bi6af.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
......
...@@ -5,16 +5,35 @@ ...@@ -5,16 +5,35 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include "test.hpp" #include "test.hpp"
migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode = false)
{
auto prog = migraphx::parse_onnx(name);
if(eliminate_deadcode)
migraphx::run_passes(prog, {migraphx::dead_code_elimination{}});
// remove the last identity instruction
auto last_ins = std::prev(prog.end());
if(last_ins->name() == "identity")
{
prog.remove_instruction(last_ins);
}
return prog;
}
TEST_CASE(acos_test) TEST_CASE(acos_test)
{ {
migraphx::program p; migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::acos{}, input); p.add_instruction(migraphx::op::acos{}, input);
auto prog = migraphx::parse_onnx("acos_test.onnx"); auto prog = optimize_onnx("acos_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -27,7 +46,7 @@ TEST_CASE(add_bcast_test) ...@@ -27,7 +46,7 @@ TEST_CASE(add_bcast_test)
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1); auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2); p.add_instruction(migraphx::op::add{}, l0, l2);
auto prog = migraphx::parse_onnx("add_bcast_test.onnx"); auto prog = optimize_onnx("add_bcast_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -40,7 +59,7 @@ TEST_CASE(add_fp16_test) ...@@ -40,7 +59,7 @@ TEST_CASE(add_fp16_test)
auto l1 = auto l1 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type, {1}}, {2.5}}); p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type, {1}}, {2.5}});
p.add_instruction(migraphx::op::add{}, l0, l1); p.add_instruction(migraphx::op::add{}, l0, l1);
auto prog = migraphx::parse_onnx("add_fp16_test.onnx"); auto prog = optimize_onnx("add_fp16_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -52,7 +71,7 @@ TEST_CASE(add_scalar_test) ...@@ -52,7 +71,7 @@ TEST_CASE(add_scalar_test)
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, l0, m1); p.add_instruction(migraphx::op::add{}, l0, m1);
auto prog = migraphx::parse_onnx("add_scalar_test.onnx"); auto prog = optimize_onnx("add_scalar_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -63,7 +82,7 @@ TEST_CASE(argmax_test) ...@@ -63,7 +82,7 @@ TEST_CASE(argmax_test)
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto ins = p.add_instruction(migraphx::op::argmax{2}, l0); auto ins = p.add_instruction(migraphx::op::argmax{2}, l0);
p.add_instruction(migraphx::op::squeeze{{2}}, ins); p.add_instruction(migraphx::op::squeeze{{2}}, ins);
auto prog = migraphx::parse_onnx("argmax_test.onnx"); auto prog = optimize_onnx("argmax_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -74,7 +93,7 @@ TEST_CASE(argmin_test) ...@@ -74,7 +93,7 @@ TEST_CASE(argmin_test)
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto ins = p.add_instruction(migraphx::op::argmin{3}, l0); auto ins = p.add_instruction(migraphx::op::argmin{3}, l0);
p.add_instruction(migraphx::op::squeeze{{3}}, ins); p.add_instruction(migraphx::op::squeeze{{3}}, ins);
auto prog = migraphx::parse_onnx("argmin_test.onnx"); auto prog = optimize_onnx("argmin_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -85,7 +104,7 @@ TEST_CASE(asin_test) ...@@ -85,7 +104,7 @@ TEST_CASE(asin_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::asin{}, input); p.add_instruction(migraphx::op::asin{}, input);
auto prog = migraphx::parse_onnx("asin_test.onnx"); auto prog = optimize_onnx("asin_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -96,7 +115,7 @@ TEST_CASE(atan_test) ...@@ -96,7 +115,7 @@ TEST_CASE(atan_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::atan{}, input); p.add_instruction(migraphx::op::atan{}, input);
auto prog = migraphx::parse_onnx("atan_test.onnx"); auto prog = optimize_onnx("atan_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -107,7 +126,7 @@ TEST_CASE(cast_test) ...@@ -107,7 +126,7 @@ TEST_CASE(cast_test)
auto l = p.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {10}}); auto l = p.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {10}});
p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, l); p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, l);
auto prog = migraphx::parse_onnx("cast_test.onnx"); auto prog = optimize_onnx("cast_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -117,7 +136,7 @@ TEST_CASE(ceil_test) ...@@ -117,7 +136,7 @@ TEST_CASE(ceil_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::ceil{}, input); p.add_instruction(migraphx::op::ceil{}, input);
auto prog = migraphx::parse_onnx("ceil_test.onnx"); auto prog = optimize_onnx("ceil_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -127,7 +146,7 @@ TEST_CASE(clip_test) ...@@ -127,7 +146,7 @@ TEST_CASE(clip_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0); p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0);
auto prog = migraphx::parse_onnx("clip_test.onnx"); auto prog = optimize_onnx("clip_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -138,7 +157,7 @@ TEST_CASE(concat_test) ...@@ -138,7 +157,7 @@ TEST_CASE(concat_test)
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4, 3}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7, 4, 3}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7, 4, 3}});
p.add_instruction(migraphx::op::concat{0}, l0, l1); p.add_instruction(migraphx::op::concat{0}, l0, l1);
auto prog = migraphx::parse_onnx("concat_test.onnx"); auto prog = optimize_onnx("concat_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -147,7 +166,7 @@ TEST_CASE(constant_test) ...@@ -147,7 +166,7 @@ TEST_CASE(constant_test)
{ {
migraphx::program p; migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0, 1, 2}}); p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0, 1, 2}});
auto prog = migraphx::parse_onnx("constant_test.onnx"); auto prog = optimize_onnx("constant_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -159,7 +178,7 @@ TEST_CASE(constant_fill_test) ...@@ -159,7 +178,7 @@ TEST_CASE(constant_fill_test)
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> value(s.elements(), 1.0); std::vector<float> value(s.elements(), 1.0);
p.add_literal(migraphx::literal{s, value}); p.add_literal(migraphx::literal{s, value});
auto prog = migraphx::parse_onnx("constant_fill_test.onnx"); auto prog = optimize_onnx("constant_fill_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -174,7 +193,7 @@ TEST_CASE(constant_fill_input_as_shape_test) ...@@ -174,7 +193,7 @@ TEST_CASE(constant_fill_input_as_shape_test)
migraphx::shape s{migraphx::shape::float_type, dims}; migraphx::shape s{migraphx::shape::float_type, dims};
std::vector<float> value(s.elements(), 1.0); std::vector<float> value(s.elements(), 1.0);
p.add_literal(migraphx::literal{s, value}); p.add_literal(migraphx::literal{s, value});
auto prog = migraphx::parse_onnx("constant_fill_input_as_shape_test.onnx"); auto prog = optimize_onnx("constant_fill_input_as_shape_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -183,7 +202,7 @@ TEST_CASE(constant_scalar_test) ...@@ -183,7 +202,7 @@ TEST_CASE(constant_scalar_test)
{ {
migraphx::program p; migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1}}, {1}}); p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1}}, {1}});
auto prog = migraphx::parse_onnx("constant_scalar_test.onnx"); auto prog = optimize_onnx("constant_scalar_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -196,7 +215,7 @@ TEST_CASE(const_of_shape_empty_input_test) ...@@ -196,7 +215,7 @@ TEST_CASE(const_of_shape_empty_input_test)
std::vector<int64_t> vec(s.elements(), 10); std::vector<int64_t> vec(s.elements(), 10);
p.add_literal(migraphx::literal(s, vec)); p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape_empty_input_test.onnx"); auto prog = optimize_onnx("const_of_shape_empty_input_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -209,7 +228,7 @@ TEST_CASE(const_of_shape_float_test) ...@@ -209,7 +228,7 @@ TEST_CASE(const_of_shape_float_test)
std::vector<float> vec(s.elements(), 10.0f); std::vector<float> vec(s.elements(), 10.0f);
p.add_literal(migraphx::literal(s, vec)); p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape_float_test.onnx"); auto prog = optimize_onnx("const_of_shape_float_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -222,7 +241,7 @@ TEST_CASE(const_of_shape_int64_test) ...@@ -222,7 +241,7 @@ TEST_CASE(const_of_shape_int64_test)
std::vector<int64_t> vec(s.elements(), 10); std::vector<int64_t> vec(s.elements(), 10);
p.add_literal(migraphx::literal(s, vec)); p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape_int64_test.onnx"); auto prog = optimize_onnx("const_of_shape_int64_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -235,13 +254,13 @@ TEST_CASE(const_of_shape_no_value_attr_test) ...@@ -235,13 +254,13 @@ TEST_CASE(const_of_shape_no_value_attr_test)
std::vector<float> vec(s.elements(), 0.0f); std::vector<float> vec(s.elements(), 0.0f);
p.add_literal(migraphx::literal(s, vec)); p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape_no_value_attr_test.onnx"); auto prog = optimize_onnx("const_of_shape_no_value_attr_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(conv_autopad_fail_test) TEST_CASE(conv_autopad_fail_test)
{ {
EXPECT(test::throws([&] { migraphx::parse_onnx("conv_autopad_fail_test.onnx"); })); EXPECT(test::throws([&] { optimize_onnx("conv_autopad_fail_test.onnx"); }));
} }
TEST_CASE(conv_bias_test) TEST_CASE(conv_bias_test)
...@@ -255,7 +274,7 @@ TEST_CASE(conv_bias_test) ...@@ -255,7 +274,7 @@ TEST_CASE(conv_bias_test)
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2); auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
p.add_instruction(migraphx::op::add{}, l3, l4); p.add_instruction(migraphx::op::add{}, l3, l4);
auto prog = migraphx::parse_onnx("conv_bias_test.onnx"); auto prog = optimize_onnx("conv_bias_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -278,7 +297,7 @@ TEST_CASE(conv_bn_relu_maxpool_test) ...@@ -278,7 +297,7 @@ TEST_CASE(conv_bn_relu_maxpool_test)
auto l7 = p.add_instruction(migraphx::op::relu{}, l6); auto l7 = p.add_instruction(migraphx::op::relu{}, l6);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7); p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7);
auto prog = migraphx::parse_onnx("conv_bn_relu_maxpool_test.onnx"); auto prog = optimize_onnx("conv_bn_relu_maxpool_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -295,7 +314,7 @@ TEST_CASE(conv_relu_maxpool_test) ...@@ -295,7 +314,7 @@ TEST_CASE(conv_relu_maxpool_test)
auto l6 = p.add_instruction(migraphx::op::relu{}, l5); auto l6 = p.add_instruction(migraphx::op::relu{}, l5);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6); p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto prog = migraphx::parse_onnx("conv_relu_maxpool_test.onnx"); auto prog = optimize_onnx("conv_relu_maxpool_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -320,7 +339,7 @@ TEST_CASE(conv_relu_maxpool_x2_test) ...@@ -320,7 +339,7 @@ TEST_CASE(conv_relu_maxpool_x2_test)
auto l13 = p.add_instruction(migraphx::op::relu{}, l12); auto l13 = p.add_instruction(migraphx::op::relu{}, l12);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13); p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13);
auto prog = migraphx::parse_onnx("conv_relu_maxpool_x2_test.onnx"); auto prog = optimize_onnx("conv_relu_maxpool_x2_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -331,7 +350,7 @@ TEST_CASE(cos_test) ...@@ -331,7 +350,7 @@ TEST_CASE(cos_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::cos{}, input); p.add_instruction(migraphx::op::cos{}, input);
auto prog = migraphx::parse_onnx("cos_test.onnx"); auto prog = optimize_onnx("cos_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -341,7 +360,7 @@ TEST_CASE(cosh_test) ...@@ -341,7 +360,7 @@ TEST_CASE(cosh_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}});
p.add_instruction(migraphx::op::cosh{}, input); p.add_instruction(migraphx::op::cosh{}, input);
auto prog = migraphx::parse_onnx("cosh_test.onnx"); auto prog = optimize_onnx("cosh_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -352,7 +371,7 @@ TEST_CASE(dropout_test) ...@@ -352,7 +371,7 @@ TEST_CASE(dropout_test)
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}}); auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}});
p.add_instruction(migraphx::op::identity{}, input); p.add_instruction(migraphx::op::identity{}, input);
auto prog = migraphx::parse_onnx("dropout_test.onnx"); auto prog = optimize_onnx("dropout_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -363,7 +382,7 @@ TEST_CASE(elu_test) ...@@ -363,7 +382,7 @@ TEST_CASE(elu_test)
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::elu{0.01}, input); p.add_instruction(migraphx::op::elu{0.01}, input);
auto prog = migraphx::parse_onnx("elu_test.onnx"); auto prog = optimize_onnx("elu_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -374,7 +393,7 @@ TEST_CASE(erf_test) ...@@ -374,7 +393,7 @@ TEST_CASE(erf_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
p.add_instruction(migraphx::op::erf{}, input); p.add_instruction(migraphx::op::erf{}, input);
auto prog = migraphx::parse_onnx("erf_test.onnx"); auto prog = optimize_onnx("erf_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -384,7 +403,7 @@ TEST_CASE(exp_test) ...@@ -384,7 +403,7 @@ TEST_CASE(exp_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::exp{}, input); p.add_instruction(migraphx::op::exp{}, input);
auto prog = migraphx::parse_onnx("exp_test.onnx"); auto prog = optimize_onnx("exp_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -397,7 +416,7 @@ TEST_CASE(expand_test) ...@@ -397,7 +416,7 @@ TEST_CASE(expand_test)
p.add_literal(migraphx::literal(ss, {2, 3, 4, 5})); p.add_literal(migraphx::literal(ss, {2, 3, 4, 5}));
p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, param); p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, param);
auto prog = migraphx::parse_onnx("expand_test.onnx"); auto prog = optimize_onnx("expand_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -407,7 +426,7 @@ TEST_CASE(flatten_test) ...@@ -407,7 +426,7 @@ TEST_CASE(flatten_test)
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
p.add_instruction(migraphx::op::flatten{2}, l0); p.add_instruction(migraphx::op::flatten{2}, l0);
p.add_instruction(migraphx::op::flatten{1}, l0); p.add_instruction(migraphx::op::flatten{1}, l0);
auto prog = migraphx::parse_onnx("flatten_test.onnx"); auto prog = optimize_onnx("flatten_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -418,7 +437,7 @@ TEST_CASE(floor_test) ...@@ -418,7 +437,7 @@ TEST_CASE(floor_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::floor{}, input); p.add_instruction(migraphx::op::floor{}, input);
auto prog = migraphx::parse_onnx("floor_test.onnx"); auto prog = optimize_onnx("floor_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -430,7 +449,7 @@ TEST_CASE(gather_test) ...@@ -430,7 +449,7 @@ TEST_CASE(gather_test)
auto l1 = p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}}); auto l1 = p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}});
int axis = 1; int axis = 1;
p.add_instruction(migraphx::op::gather{axis}, l0, l1); p.add_instruction(migraphx::op::gather{axis}, l0, l1);
auto prog = migraphx::parse_onnx("gather_test.onnx"); auto prog = optimize_onnx("gather_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -447,7 +466,7 @@ TEST_CASE(gemm_test) ...@@ -447,7 +466,7 @@ TEST_CASE(gemm_test)
auto alpha = 2.f; auto alpha = 2.f;
auto beta = 2.0f; auto beta = 2.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, t1, bl2); p.add_instruction(migraphx::op::dot{alpha, beta}, t0, t1, bl2);
auto prog = migraphx::parse_onnx("gemm_test.onnx"); auto prog = optimize_onnx("gemm_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -462,7 +481,7 @@ TEST_CASE(gemm_ex_test) ...@@ -462,7 +481,7 @@ TEST_CASE(gemm_ex_test)
auto alpha = 0.5f; auto alpha = 0.5f;
auto beta = 0.8f; auto beta = 0.8f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, l2); p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, l2);
auto prog = migraphx::parse_onnx("gemm_ex_test.onnx"); auto prog = optimize_onnx("gemm_ex_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -479,7 +498,7 @@ TEST_CASE(gemm_ex_brcst_test) ...@@ -479,7 +498,7 @@ TEST_CASE(gemm_ex_brcst_test)
auto alpha = 0.5f; auto alpha = 0.5f;
auto beta = 0.8f; auto beta = 0.8f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, t2); p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, t2);
auto prog = migraphx::parse_onnx("gemm_ex_brcst_test.onnx"); auto prog = optimize_onnx("gemm_ex_brcst_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -493,7 +512,7 @@ TEST_CASE(globalavgpool_test) ...@@ -493,7 +512,7 @@ TEST_CASE(globalavgpool_test)
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
p.add_instruction(op, input); p.add_instruction(op, input);
auto prog = migraphx::parse_onnx("globalavgpool_test.onnx"); auto prog = optimize_onnx("globalavgpool_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -507,7 +526,7 @@ TEST_CASE(globalmaxpool_test) ...@@ -507,7 +526,7 @@ TEST_CASE(globalmaxpool_test)
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
p.add_instruction(op, input); p.add_instruction(op, input);
auto prog = migraphx::parse_onnx("globalmaxpool_test.onnx"); auto prog = optimize_onnx("globalmaxpool_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -520,7 +539,7 @@ TEST_CASE(group_conv_test) ...@@ -520,7 +539,7 @@ TEST_CASE(group_conv_test)
migraphx::op::convolution op; migraphx::op::convolution op;
op.group = 4; op.group = 4;
p.add_instruction(op, l0, l1); p.add_instruction(op, l0, l1);
auto prog = migraphx::parse_onnx("group_conv_test.onnx"); auto prog = optimize_onnx("group_conv_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -538,7 +557,7 @@ TEST_CASE(imagescaler_test) ...@@ -538,7 +557,7 @@ TEST_CASE(imagescaler_test)
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals); auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast); p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
auto prog = migraphx::parse_onnx("imagescaler_test.onnx"); auto prog = optimize_onnx("imagescaler_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -551,7 +570,7 @@ TEST_CASE(implicit_add_bcast_test) ...@@ -551,7 +570,7 @@ TEST_CASE(implicit_add_bcast_test)
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, l0, l3); p.add_instruction(migraphx::op::add{}, l0, l3);
auto prog = migraphx::parse_onnx("implicit_add_bcast_test.onnx"); auto prog = optimize_onnx("implicit_add_bcast_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -564,7 +583,7 @@ TEST_CASE(implicit_pow_bcast_test) ...@@ -564,7 +583,7 @@ TEST_CASE(implicit_pow_bcast_test)
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::pow{}, l0, l3); p.add_instruction(migraphx::op::pow{}, l0, l3);
auto prog = migraphx::parse_onnx("implicit_pow_bcast_test.onnx"); auto prog = optimize_onnx("implicit_pow_bcast_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -577,7 +596,7 @@ TEST_CASE(implicit_sub_bcast_test) ...@@ -577,7 +596,7 @@ TEST_CASE(implicit_sub_bcast_test)
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::sub{}, l0, l3); p.add_instruction(migraphx::op::sub{}, l0, l3);
auto prog = migraphx::parse_onnx("implicit_sub_bcast_test.onnx"); auto prog = optimize_onnx("implicit_sub_bcast_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -590,7 +609,7 @@ TEST_CASE(initializer_not_an_input) ...@@ -590,7 +609,7 @@ TEST_CASE(initializer_not_an_input)
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 2}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 2}});
p.add_instruction(migraphx::op::dot{}, l0, l1); p.add_instruction(migraphx::op::dot{}, l0, l1);
auto prog = migraphx::parse_onnx("initializer_not_an_input.onnx"); auto prog = optimize_onnx("initializer_not_an_input.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -602,7 +621,7 @@ TEST_CASE(leaky_relu_test) ...@@ -602,7 +621,7 @@ TEST_CASE(leaky_relu_test)
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {3}}); auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::leaky_relu{alpha}, l0); p.add_instruction(migraphx::op::leaky_relu{alpha}, l0);
auto prog = migraphx::parse_onnx("leaky_relu_test.onnx"); auto prog = optimize_onnx("leaky_relu_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -613,7 +632,7 @@ TEST_CASE(log_test) ...@@ -613,7 +632,7 @@ TEST_CASE(log_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::log{}, input); p.add_instruction(migraphx::op::log{}, input);
auto prog = migraphx::parse_onnx("log_test.onnx"); auto prog = optimize_onnx("log_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -623,7 +642,7 @@ TEST_CASE(logsoftmax_test) ...@@ -623,7 +642,7 @@ TEST_CASE(logsoftmax_test)
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
int axis = 1; int axis = 1;
p.add_instruction(migraphx::op::logsoftmax{axis}, l0); p.add_instruction(migraphx::op::logsoftmax{axis}, l0);
auto prog = migraphx::parse_onnx("logsoftmax_test.onnx"); auto prog = optimize_onnx("logsoftmax_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -638,7 +657,7 @@ TEST_CASE(lrn_test) ...@@ -638,7 +657,7 @@ TEST_CASE(lrn_test)
op.beta = 0.75; op.beta = 0.75;
op.bias = 1.0; op.bias = 1.0;
p.add_instruction(op, l0); p.add_instruction(op, l0);
auto prog = migraphx::parse_onnx("lrn_test.onnx"); auto prog = optimize_onnx("lrn_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -652,7 +671,7 @@ TEST_CASE(matmul_bmbm_test) ...@@ -652,7 +671,7 @@ TEST_CASE(matmul_bmbm_test)
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 7, 8}}, l1); auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 7, 8}}, l1);
p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bl0, bl1); p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bl0, bl1);
auto prog = migraphx::parse_onnx("matmul_bmbm_test.onnx"); auto prog = optimize_onnx("matmul_bmbm_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -667,7 +686,7 @@ TEST_CASE(matmul_bmv_test) ...@@ -667,7 +686,7 @@ TEST_CASE(matmul_bmv_test)
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, bsl1); auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, bsl1);
p.add_instruction(migraphx::op::squeeze{{2}}, res); p.add_instruction(migraphx::op::squeeze{{2}}, res);
auto prog = migraphx::parse_onnx("matmul_bmv_test.onnx"); auto prog = optimize_onnx("matmul_bmv_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -681,7 +700,7 @@ TEST_CASE(matmul_mv_test) ...@@ -681,7 +700,7 @@ TEST_CASE(matmul_mv_test)
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, sl1); auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, sl1);
p.add_instruction(migraphx::op::squeeze{{1}}, res); p.add_instruction(migraphx::op::squeeze{{1}}, res);
auto prog = migraphx::parse_onnx("matmul_mv_test.onnx"); auto prog = optimize_onnx("matmul_mv_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -698,7 +717,7 @@ TEST_CASE(matmul_vbm_test) ...@@ -698,7 +717,7 @@ TEST_CASE(matmul_vbm_test)
std::cout << "After Dot" << std::endl; std::cout << "After Dot" << std::endl;
p.add_instruction(migraphx::op::squeeze{{1}}, res); p.add_instruction(migraphx::op::squeeze{{1}}, res);
auto prog = migraphx::parse_onnx("matmul_vbm_test.onnx"); auto prog = optimize_onnx("matmul_vbm_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -712,7 +731,7 @@ TEST_CASE(matmul_vm_test) ...@@ -712,7 +731,7 @@ TEST_CASE(matmul_vm_test)
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, sl0, l1); auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, sl0, l1);
p.add_instruction(migraphx::op::squeeze{{0}}, res); p.add_instruction(migraphx::op::squeeze{{0}}, res);
auto prog = migraphx::parse_onnx("matmul_vm_test.onnx"); auto prog = optimize_onnx("matmul_vm_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -728,7 +747,7 @@ TEST_CASE(matmul_vv_test) ...@@ -728,7 +747,7 @@ TEST_CASE(matmul_vv_test)
auto sr0 = p.add_instruction(migraphx::op::squeeze{{0}}, res); auto sr0 = p.add_instruction(migraphx::op::squeeze{{0}}, res);
p.add_instruction(migraphx::op::squeeze{{0}}, sr0); p.add_instruction(migraphx::op::squeeze{{0}}, sr0);
auto prog = migraphx::parse_onnx("matmul_vv_test.onnx"); auto prog = optimize_onnx("matmul_vv_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -742,7 +761,7 @@ TEST_CASE(max_test) ...@@ -742,7 +761,7 @@ TEST_CASE(max_test)
auto l0 = p.add_instruction(migraphx::op::max{}, input0, input1); auto l0 = p.add_instruction(migraphx::op::max{}, input0, input1);
p.add_instruction(migraphx::op::max{}, l0, input2); p.add_instruction(migraphx::op::max{}, l0, input2);
migraphx::parse_onnx("max_test.onnx"); optimize_onnx("max_test.onnx");
} }
TEST_CASE(min_test) TEST_CASE(min_test)
...@@ -754,7 +773,7 @@ TEST_CASE(min_test) ...@@ -754,7 +773,7 @@ TEST_CASE(min_test)
auto l0 = p.add_instruction(migraphx::op::min{}, input0, input1); auto l0 = p.add_instruction(migraphx::op::min{}, input0, input1);
p.add_instruction(migraphx::op::min{}, l0, input2); p.add_instruction(migraphx::op::min{}, l0, input2);
migraphx::parse_onnx("min_test.onnx"); optimize_onnx("min_test.onnx");
} }
TEST_CASE(no_pad_test) TEST_CASE(no_pad_test)
...@@ -762,7 +781,7 @@ TEST_CASE(no_pad_test) ...@@ -762,7 +781,7 @@ TEST_CASE(no_pad_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
p.add_instruction(migraphx::op::identity{}, l0); p.add_instruction(migraphx::op::identity{}, l0);
auto prog = migraphx::parse_onnx("no_pad_test.onnx"); auto prog = optimize_onnx("no_pad_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -772,7 +791,7 @@ TEST_CASE(pad_test) ...@@ -772,7 +791,7 @@ TEST_CASE(pad_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
p.add_instruction(migraphx::op::pad{{1, 1, 1, 1}}, l0); p.add_instruction(migraphx::op::pad{{1, 1, 1, 1}}, l0);
auto prog = migraphx::parse_onnx("pad_test.onnx"); auto prog = optimize_onnx("pad_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -784,7 +803,7 @@ TEST_CASE(pow_test) ...@@ -784,7 +803,7 @@ TEST_CASE(pow_test)
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
p.add_instruction(migraphx::op::pow{}, l0, l1); p.add_instruction(migraphx::op::pow{}, l0, l1);
auto prog = migraphx::parse_onnx("pow_test.onnx"); auto prog = optimize_onnx("pow_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -794,7 +813,7 @@ TEST_CASE(reducemax_test) ...@@ -794,7 +813,7 @@ TEST_CASE(reducemax_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::reduce_max{{2}}, l0); p.add_instruction(migraphx::op::reduce_max{{2}}, l0);
auto prog = migraphx::parse_onnx("reducemax_test.onnx"); auto prog = optimize_onnx("reducemax_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -805,7 +824,7 @@ TEST_CASE(reducemean_test) ...@@ -805,7 +824,7 @@ TEST_CASE(reducemean_test)
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_mean{{2, 3}}, l0); auto l1 = p.add_instruction(migraphx::op::reduce_mean{{2, 3}}, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1); p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1);
auto prog = migraphx::parse_onnx("reducemean_test.onnx"); auto prog = optimize_onnx("reducemean_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -815,7 +834,7 @@ TEST_CASE(reducemean_keepdims_test) ...@@ -815,7 +834,7 @@ TEST_CASE(reducemean_keepdims_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::reduce_mean{{2}}, l0); p.add_instruction(migraphx::op::reduce_mean{{2}}, l0);
auto prog = migraphx::parse_onnx("reducemean_keepdims_test.onnx"); auto prog = optimize_onnx("reducemean_keepdims_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -826,7 +845,7 @@ TEST_CASE(reducemin_test) ...@@ -826,7 +845,7 @@ TEST_CASE(reducemin_test)
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_min{{2, 3}}, l0); auto l1 = p.add_instruction(migraphx::op::reduce_min{{2, 3}}, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1); p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1);
auto prog = migraphx::parse_onnx("reducemin_test.onnx"); auto prog = optimize_onnx("reducemin_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -837,7 +856,7 @@ TEST_CASE(reducesum_test) ...@@ -837,7 +856,7 @@ TEST_CASE(reducesum_test)
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_sum{{2}}, l0); auto l1 = p.add_instruction(migraphx::op::reduce_sum{{2}}, l0);
p.add_instruction(migraphx::op::squeeze{{2}}, l1); p.add_instruction(migraphx::op::squeeze{{2}}, l1);
auto prog = migraphx::parse_onnx("reducesum_test.onnx"); auto prog = optimize_onnx("reducesum_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -848,7 +867,7 @@ TEST_CASE(reducesum_multiaxis_test) ...@@ -848,7 +867,7 @@ TEST_CASE(reducesum_multiaxis_test)
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_sum{{2, 3}}, l0); auto l1 = p.add_instruction(migraphx::op::reduce_sum{{2, 3}}, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1); p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1);
auto prog = migraphx::parse_onnx("reducesum_multiaxis_test.onnx"); auto prog = optimize_onnx("reducesum_multiaxis_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -858,7 +877,7 @@ TEST_CASE(reducesum_keepdims_test) ...@@ -858,7 +877,7 @@ TEST_CASE(reducesum_keepdims_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::reduce_sum{{2, 3}}, l0); p.add_instruction(migraphx::op::reduce_sum{{2, 3}}, l0);
auto prog = migraphx::parse_onnx("reducesum_keepdims_test.onnx"); auto prog = optimize_onnx("reducesum_keepdims_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -874,7 +893,7 @@ TEST_CASE(reshape_test) ...@@ -874,7 +893,7 @@ TEST_CASE(reshape_test)
op.dims = reshape_dims; op.dims = reshape_dims;
p.add_instruction(op, l0); p.add_instruction(op, l0);
p.add_instruction(op, l0); p.add_instruction(op, l0);
auto prog = migraphx::parse_onnx("reshape_test.onnx"); auto prog = optimize_onnx("reshape_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -889,7 +908,7 @@ TEST_CASE(reshape_non_standard_test) ...@@ -889,7 +908,7 @@ TEST_CASE(reshape_non_standard_test)
auto tran_x = p.add_instruction(migraphx::op::transpose{{0, 2, 1}}, x); auto tran_x = p.add_instruction(migraphx::op::transpose{{0, 2, 1}}, x);
auto cont_x = p.add_instruction(migraphx::op::contiguous{}, tran_x); auto cont_x = p.add_instruction(migraphx::op::contiguous{}, tran_x);
p.add_instruction(migraphx::op::reshape{{4, 3, 2}}, cont_x); p.add_instruction(migraphx::op::reshape{{4, 3, 2}}, cont_x);
auto prog = migraphx::parse_onnx("reshape_non_standard_test.onnx"); auto prog = optimize_onnx("reshape_non_standard_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -900,7 +919,7 @@ TEST_CASE(round_test) ...@@ -900,7 +919,7 @@ TEST_CASE(round_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}});
p.add_instruction(migraphx::op::round{}, input); p.add_instruction(migraphx::op::round{}, input);
auto prog = migraphx::parse_onnx("round_test.onnx"); auto prog = optimize_onnx("round_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -911,7 +930,7 @@ TEST_CASE(shape_test) ...@@ -911,7 +930,7 @@ TEST_CASE(shape_test)
auto l0 = p.add_parameter("x", s); auto l0 = p.add_parameter("x", s);
migraphx::shape s_shape{migraphx::shape::int64_type, {4}}; migraphx::shape s_shape{migraphx::shape::int64_type, {4}};
p.add_literal(s_shape, l0->get_shape().lens()); p.add_literal(s_shape, l0->get_shape().lens());
auto prog = migraphx::parse_onnx("shape_test.onnx"); auto prog = optimize_onnx("shape_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -926,7 +945,7 @@ TEST_CASE(shape_gather_test) ...@@ -926,7 +945,7 @@ TEST_CASE(shape_gather_test)
auto l2 = p.add_literal(migraphx::literal{const_shape, {1}}); auto l2 = p.add_literal(migraphx::literal{const_shape, {1}});
int axis = 0; int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, l1, l2); p.add_instruction(migraphx::op::gather{axis}, l1, l2);
auto prog = migraphx::parse_onnx("shape_gather_test.onnx"); auto prog = optimize_onnx("shape_gather_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -937,7 +956,7 @@ TEST_CASE(sign_test) ...@@ -937,7 +956,7 @@ TEST_CASE(sign_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}});
p.add_instruction(migraphx::op::sign{}, input); p.add_instruction(migraphx::op::sign{}, input);
auto prog = migraphx::parse_onnx("sign_test.onnx"); auto prog = optimize_onnx("sign_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -947,7 +966,7 @@ TEST_CASE(sin_test) ...@@ -947,7 +966,7 @@ TEST_CASE(sin_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::sin{}, input); p.add_instruction(migraphx::op::sin{}, input);
auto prog = migraphx::parse_onnx("sin_test.onnx"); auto prog = optimize_onnx("sin_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -957,7 +976,7 @@ TEST_CASE(sinh_test) ...@@ -957,7 +976,7 @@ TEST_CASE(sinh_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::sinh{}, input); p.add_instruction(migraphx::op::sinh{}, input);
auto prog = migraphx::parse_onnx("sinh_test.onnx"); auto prog = optimize_onnx("sinh_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -967,7 +986,7 @@ TEST_CASE(slice_test) ...@@ -967,7 +986,7 @@ TEST_CASE(slice_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 2}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 2}});
p.add_instruction(migraphx::op::slice{{0, 1}, {1, 0}, {2, 2}}, l0); p.add_instruction(migraphx::op::slice{{0, 1}, {1, 0}, {2, 2}}, l0);
auto prog = migraphx::parse_onnx("slice_test.onnx"); auto prog = optimize_onnx("slice_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -977,7 +996,7 @@ TEST_CASE(softmax_test) ...@@ -977,7 +996,7 @@ TEST_CASE(softmax_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
p.add_instruction(migraphx::op::softmax{1}, l0); p.add_instruction(migraphx::op::softmax{1}, l0);
auto prog = migraphx::parse_onnx("softmax_test.onnx"); auto prog = optimize_onnx("softmax_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -988,7 +1007,7 @@ TEST_CASE(sqrt_test) ...@@ -988,7 +1007,7 @@ TEST_CASE(sqrt_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
p.add_instruction(migraphx::op::sqrt{}, input); p.add_instruction(migraphx::op::sqrt{}, input);
auto prog = migraphx::parse_onnx("sqrt_test.onnx"); auto prog = optimize_onnx("sqrt_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1001,7 +1020,7 @@ TEST_CASE(squeeze_unsqueeze_test) ...@@ -1001,7 +1020,7 @@ TEST_CASE(squeeze_unsqueeze_test)
p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 1, 1, 2, 1}}); p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 1, 1, 2, 1}});
auto l1 = p.add_instruction(migraphx::op::squeeze{squeeze_axes}, l0); auto l1 = p.add_instruction(migraphx::op::squeeze{squeeze_axes}, l0);
p.add_instruction(migraphx::op::unsqueeze{unsqueeze_axes}, l1); p.add_instruction(migraphx::op::unsqueeze{unsqueeze_axes}, l1);
auto prog = migraphx::parse_onnx("squeeze_unsqueeze_test.onnx"); auto prog = optimize_onnx("squeeze_unsqueeze_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1014,7 +1033,7 @@ TEST_CASE(sub_bcast_test) ...@@ -1014,7 +1033,7 @@ TEST_CASE(sub_bcast_test)
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1); auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::sub{}, l0, l2); p.add_instruction(migraphx::op::sub{}, l0, l2);
auto prog = migraphx::parse_onnx("sub_bcast_test.onnx"); auto prog = optimize_onnx("sub_bcast_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1027,7 +1046,7 @@ TEST_CASE(sub_scalar_test) ...@@ -1027,7 +1046,7 @@ TEST_CASE(sub_scalar_test)
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}}); p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}});
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::sub{}, l0, m1); p.add_instruction(migraphx::op::sub{}, l0, m1);
auto prog = migraphx::parse_onnx("sub_scalar_test.onnx"); auto prog = optimize_onnx("sub_scalar_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1041,7 +1060,7 @@ TEST_CASE(sum_test) ...@@ -1041,7 +1060,7 @@ TEST_CASE(sum_test)
auto l0 = p.add_instruction(migraphx::op::add{}, input0, input1); auto l0 = p.add_instruction(migraphx::op::add{}, input0, input1);
p.add_instruction(migraphx::op::add{}, l0, input2); p.add_instruction(migraphx::op::add{}, l0, input2);
auto prog = migraphx::parse_onnx("sum_test.onnx"); auto prog = optimize_onnx("sum_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1051,7 +1070,7 @@ TEST_CASE(tan_test) ...@@ -1051,7 +1070,7 @@ TEST_CASE(tan_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::tan{}, input); p.add_instruction(migraphx::op::tan{}, input);
auto prog = migraphx::parse_onnx("tan_test.onnx"); auto prog = optimize_onnx("tan_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1061,7 +1080,7 @@ TEST_CASE(tanh_test) ...@@ -1061,7 +1080,7 @@ TEST_CASE(tanh_test)
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}});
p.add_instruction(migraphx::op::tanh{}, input); p.add_instruction(migraphx::op::tanh{}, input);
auto prog = migraphx::parse_onnx("tanh_test.onnx"); auto prog = optimize_onnx("tanh_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1073,7 +1092,7 @@ TEST_CASE(transpose_test) ...@@ -1073,7 +1092,7 @@ TEST_CASE(transpose_test)
std::vector<int64_t> perm{0, 3, 1, 2}; std::vector<int64_t> perm{0, 3, 1, 2};
p.add_instruction(migraphx::op::transpose{perm}, input); p.add_instruction(migraphx::op::transpose{perm}, input);
auto prog = migraphx::parse_onnx("transpose_test.onnx"); auto prog = optimize_onnx("transpose_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1099,7 +1118,7 @@ TEST_CASE(transpose_gather_test) ...@@ -1099,7 +1118,7 @@ TEST_CASE(transpose_gather_test)
p.add_instruction( p.add_instruction(
migraphx::op::gather{axis}, make_contiguous(tr_data), make_contiguous(tr_ind)); migraphx::op::gather{axis}, make_contiguous(tr_data), make_contiguous(tr_ind));
auto prog = migraphx::parse_onnx("transpose_gather_test.onnx"); auto prog = optimize_onnx("transpose_gather_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1111,7 +1130,7 @@ TEST_CASE(unknown_test) ...@@ -1111,7 +1130,7 @@ TEST_CASE(unknown_test)
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::unknown{"Unknown"}, l0, l1); auto l2 = p.add_instruction(migraphx::op::unknown{"Unknown"}, l0, l1);
p.add_instruction(migraphx::op::unknown{"Unknown"}, l2); p.add_instruction(migraphx::op::unknown{"Unknown"}, l2);
auto prog = migraphx::parse_onnx("unknown_test.onnx"); auto prog = optimize_onnx("unknown_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1121,7 +1140,7 @@ TEST_CASE(variable_batch_test) ...@@ -1121,7 +1140,7 @@ TEST_CASE(variable_batch_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::identity{}, l0); p.add_instruction(migraphx::op::identity{}, l0);
auto prog = migraphx::parse_onnx("variable_batch_test.onnx"); auto prog = optimize_onnx("variable_batch_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1132,7 +1151,7 @@ TEST_CASE(variable_batch_leq_zero_test) ...@@ -1132,7 +1151,7 @@ TEST_CASE(variable_batch_leq_zero_test)
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::add{}, l0, l1); p.add_instruction(migraphx::op::add{}, l0, l1);
auto prog = migraphx::parse_onnx("variable_batch_leq_zero_test.onnx"); auto prog = optimize_onnx("variable_batch_leq_zero_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment