Unverified Commit 8d21fdc9 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Refactor to use make_op almost everywhere (#696)

* Load op when serializing

* Formatting

* Add missing clip field

* Use make_op almost everywhere

* Formatting

* More make ops for rnns

* Get rid of spaces

* Formatting

* Remove operators headers

* Formatting

* Remove unused op headers

* Increase line threshold
parent b5633c27
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_reverse_3args_cell_output : verify_program<test_lstm_reverse_3args_cell_output> struct test_lstm_reverse_3args_cell_output : verify_program<test_lstm_reverse_3args_cell_output>
...@@ -26,15 +30,19 @@ struct test_lstm_reverse_3args_cell_output : verify_program<test_lstm_reverse_3a ...@@ -26,15 +30,19 @@ struct test_lstm_reverse_3args_cell_output : verify_program<test_lstm_reverse_3a
auto w = mm->add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto hs = mm->add_instruction( auto hs = mm->add_instruction(
migraphx::op::lstm{ migraphx::make_op(
hidden_size, "lstm",
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {{"hidden_size", hidden_size},
migraphx::op::rnn_direction::reverse, {"actv_func",
clip}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
seq, seq,
w, w,
r); r);
mm->add_instruction(migraphx::op::rnn_last_cell_output{}, hs); mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
return p; return p;
} }
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_reverse_last : verify_program<test_lstm_reverse_last> struct test_lstm_reverse_last : verify_program<test_lstm_reverse_last>
...@@ -34,14 +38,18 @@ struct test_lstm_reverse_last : verify_program<test_lstm_reverse_last> ...@@ -34,14 +38,18 @@ struct test_lstm_reverse_last : verify_program<test_lstm_reverse_last>
auto ih = mm->add_parameter("ih", ih_shape); auto ih = mm->add_parameter("ih", ih_shape);
auto ic = mm->add_parameter("ic", ic_shape); auto ic = mm->add_parameter("ic", ic_shape);
auto pph = mm->add_parameter("pph", pph_shape); auto pph = mm->add_parameter("pph", pph_shape);
auto und = mm->add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto output = mm->add_instruction( auto output = mm->add_instruction(
migraphx::op::lstm{ migraphx::make_op(
hidden_size, "lstm",
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {{"hidden_size", hidden_size},
migraphx::op::rnn_direction::reverse, {"actv_func",
clip}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
seq, seq,
w, w,
r, r,
...@@ -50,7 +58,7 @@ struct test_lstm_reverse_last : verify_program<test_lstm_reverse_last> ...@@ -50,7 +58,7 @@ struct test_lstm_reverse_last : verify_program<test_lstm_reverse_last>
ih, ih,
ic, ic,
pph); pph);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, output); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output);
return p; return p;
} }
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_three_outputs : verify_program<test_lstm_three_outputs> struct test_lstm_three_outputs : verify_program<test_lstm_three_outputs>
...@@ -26,16 +30,20 @@ struct test_lstm_three_outputs : verify_program<test_lstm_three_outputs> ...@@ -26,16 +30,20 @@ struct test_lstm_three_outputs : verify_program<test_lstm_three_outputs>
auto w = mm->add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto hs = mm->add_instruction( auto hs = mm->add_instruction(
migraphx::op::lstm{ migraphx::make_op(
hidden_size, "lstm",
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {{"hidden_size", hidden_size},
migraphx::op::rnn_direction::forward, {"actv_func",
clip}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq, seq,
w, w,
r); r);
auto last_hs = mm->add_instruction(migraphx::op::rnn_last_hs_output{}, hs); auto last_hs = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
auto last_cell = mm->add_instruction(migraphx::op::rnn_last_cell_output{}, hs); auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
mm->add_return({hs, last_hs, last_cell}); mm->add_return({hs, last_hs, last_cell});
return p; return p;
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_two_outputs : verify_program<test_lstm_two_outputs> struct test_lstm_two_outputs : verify_program<test_lstm_two_outputs>
...@@ -26,15 +30,19 @@ struct test_lstm_two_outputs : verify_program<test_lstm_two_outputs> ...@@ -26,15 +30,19 @@ struct test_lstm_two_outputs : verify_program<test_lstm_two_outputs>
auto w = mm->add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto hs = mm->add_instruction( auto hs = mm->add_instruction(
migraphx::op::lstm{ migraphx::make_op(
hidden_size, "lstm",
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {{"hidden_size", hidden_size},
migraphx::op::rnn_direction::forward, {"actv_func",
clip}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq, seq,
w, w,
r); r);
auto last_hs = mm->add_instruction(migraphx::op::rnn_last_hs_output{}, hs); auto last_hs = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
mm->add_return({hs, last_hs}); mm->add_return({hs, last_hs});
return p; return p;
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_mul : verify_program<test_mul> struct test_mul : verify_program<test_mul>
{ {
...@@ -13,7 +13,7 @@ struct test_mul : verify_program<test_mul> ...@@ -13,7 +13,7 @@ struct test_mul : verify_program<test_mul>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
mm->add_instruction(migraphx::op::mul{}, x, y); mm->add_instruction(migraphx::make_op("mul"), x, y);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_mul_add : verify_program<test_mul_add> struct test_mul_add : verify_program<test_mul_add>
{ {
...@@ -12,13 +12,15 @@ struct test_mul_add : verify_program<test_mul_add> ...@@ -12,13 +12,15 @@ struct test_mul_add : verify_program<test_mul_add>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::shape bs{migraphx::shape::float_type, {3}}; migraphx::shape bs{migraphx::shape::float_type, {3}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto a = mm->add_parameter("a", bs); auto a = mm->add_parameter("a", bs);
auto b = mm->add_parameter("b", bs); auto b = mm->add_parameter("b", bs);
auto ab = mm->add_instruction(migraphx::op::broadcast{1, s.lens()}, a); auto ab = mm->add_instruction(
auto bb = mm->add_instruction(migraphx::op::broadcast{1, s.lens()}, b); migraphx::make_op("broadcast", {{"axis", 1}, {"dims", s.lens()}}), a);
auto mul = mm->add_instruction(migraphx::op::mul{}, x, ab); auto bb = mm->add_instruction(
mm->add_instruction(migraphx::op::add{}, mul, bb); migraphx::make_op("broadcast", {{"axis", 1}, {"dims", s.lens()}}), b);
auto mul = mm->add_instruction(migraphx::make_op("mul"), x, ab);
mm->add_instruction(migraphx::make_op("add"), mul, bb);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_neg : verify_program<test_neg> struct test_neg : verify_program<test_neg>
{ {
...@@ -13,7 +13,7 @@ struct test_neg : verify_program<test_neg> ...@@ -13,7 +13,7 @@ struct test_neg : verify_program<test_neg>
migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}}; migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}};
auto input = mm->add_parameter("x", s); auto input = mm->add_parameter("x", s);
mm->add_instruction(migraphx::op::neg{}, input); mm->add_instruction(migraphx::make_op("neg"), input);
return p; return p;
}; };
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_pad : verify_program<test_pad> struct test_pad : verify_program<test_pad>
{ {
...@@ -16,10 +16,10 @@ struct test_pad : verify_program<test_pad> ...@@ -16,10 +16,10 @@ struct test_pad : verify_program<test_pad>
std::vector<int64_t> pads2 = {1, 1, 1, 1, 0, 0, 0, 0}; std::vector<int64_t> pads2 = {1, 1, 1, 1, 0, 0, 0, 0};
std::vector<int64_t> pads3 = {1, 0, 1, 0, 1, 0, 2, 0}; std::vector<int64_t> pads3 = {1, 0, 1, 0, 1, 0, 2, 0};
auto l0 = mm->add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
mm->add_instruction(migraphx::op::pad{pads0}, l0); mm->add_instruction(migraphx::make_op("pad", {{"pads", pads0}}), l0);
mm->add_instruction(migraphx::op::pad{pads1}, l0); mm->add_instruction(migraphx::make_op("pad", {{"pads", pads1}}), l0);
mm->add_instruction(migraphx::op::pad{pads2}, l0); mm->add_instruction(migraphx::make_op("pad", {{"pads", pads2}}), l0);
mm->add_instruction(migraphx::op::pad{pads3}, l0); mm->add_instruction(migraphx::make_op("pad", {{"pads", pads3}}), l0);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_pad_transposed : verify_program<test_pad_transposed> struct test_pad_transposed : verify_program<test_pad_transposed>
{ {
...@@ -12,8 +12,8 @@ struct test_pad_transposed : verify_program<test_pad_transposed> ...@@ -12,8 +12,8 @@ struct test_pad_transposed : verify_program<test_pad_transposed>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {1, 224, 224, 3}}; migraphx::shape s{migraphx::shape::int32_type, {1, 224, 224, 3}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto t = mm->add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, x); auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 3, 1, 2}}}), x);
mm->add_instruction(migraphx::op::pad{{0, 0, 2, 2, 0, 0, 3, 3}}, t); mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 2, 2, 0, 0, 3, 3}}}), t);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_pow : verify_program<test_pow> struct test_pow : verify_program<test_pow>
{ {
...@@ -14,7 +14,7 @@ struct test_pow : verify_program<test_pow> ...@@ -14,7 +14,7 @@ struct test_pow : verify_program<test_pow>
std::vector<float> vec_e(s.elements(), 2.0f); std::vector<float> vec_e(s.elements(), 2.0f);
auto b = mm->add_parameter("x", s); auto b = mm->add_parameter("x", s);
auto e = mm->add_literal(migraphx::literal(s, vec_e)); auto e = mm->add_literal(migraphx::literal(s, vec_e));
mm->add_instruction(migraphx::op::pow{}, b, e); mm->add_instruction(migraphx::make_op("pow"), b, e);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_prelu_brcst : verify_program<test_prelu_brcst> struct test_prelu_brcst : verify_program<test_prelu_brcst>
{ {
...@@ -13,7 +13,7 @@ struct test_prelu_brcst : verify_program<test_prelu_brcst> ...@@ -13,7 +13,7 @@ struct test_prelu_brcst : verify_program<test_prelu_brcst>
migraphx::shape s{migraphx::shape::float_type, {6}}; migraphx::shape s{migraphx::shape::float_type, {6}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto slp = mm->add_parameter("slp", s); auto slp = mm->add_parameter("slp", s);
auto r = mm->add_instruction(migraphx::op::prelu{}, x, slp); auto r = mm->add_instruction(migraphx::make_op("prelu"), x, slp);
mm->add_return({r}); mm->add_return({r});
return p; return p;
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_recip : verify_program<test_recip> struct test_recip : verify_program<test_recip>
{ {
...@@ -12,7 +12,7 @@ struct test_recip : verify_program<test_recip> ...@@ -12,7 +12,7 @@ struct test_recip : verify_program<test_recip>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {3}}; migraphx::shape s{migraphx::shape::double_type, {3}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::op::recip{}, x); mm->add_instruction(migraphx::make_op("recip"), x);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_relu_lrn : verify_program<test_relu_lrn> struct test_relu_lrn : verify_program<test_relu_lrn>
{ {
...@@ -11,8 +11,11 @@ struct test_relu_lrn : verify_program<test_relu_lrn> ...@@ -11,8 +11,11 @@ struct test_relu_lrn : verify_program<test_relu_lrn>
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 5, 2, 2}}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 5, 2, 2}});
auto y = mm->add_instruction(migraphx::op::relu{}, x); auto y = mm->add_instruction(migraphx::make_op("relu"), x);
mm->add_instruction(migraphx::op::lrn{0.0001, 0.75, 1.0, 5}, y); mm->add_instruction(
migraphx::make_op("lrn",
{{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}),
y);
return p; return p;
} }
}; };
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_rnn_3args : verify_program<test_rnn_3args> struct test_rnn_3args : verify_program<test_rnn_3args>
...@@ -25,13 +29,18 @@ struct test_rnn_3args : verify_program<test_rnn_3args> ...@@ -25,13 +29,18 @@ struct test_rnn_3args : verify_program<test_rnn_3args>
auto w = mm->add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
mm->add_instruction(migraphx::op::rnn{hidden_size, mm->add_instruction(
{migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::make_op(
migraphx::op::rnn_direction::reverse, "rnn",
clip}, {{"hidden_size", hidden_size},
seq, {"actv_func",
w, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
r); migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
seq,
w,
r);
return p; return p;
} }
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_rnn_4args : verify_program<test_rnn_4args> struct test_rnn_4args : verify_program<test_rnn_4args>
...@@ -27,14 +31,19 @@ struct test_rnn_4args : verify_program<test_rnn_4args> ...@@ -27,14 +31,19 @@ struct test_rnn_4args : verify_program<test_rnn_4args>
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
mm->add_instruction(migraphx::op::rnn{hidden_size, mm->add_instruction(
{migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::make_op(
migraphx::op::rnn_direction::reverse, "rnn",
clip}, {{"hidden_size", hidden_size},
seq, {"actv_func",
w, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
r, migraphx::make_op("tanh")})},
bias); {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
seq,
w,
r,
bias);
return p; return p;
} }
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_rnn_5args : verify_program<test_rnn_5args> struct test_rnn_5args : verify_program<test_rnn_5args>
...@@ -26,19 +30,23 @@ struct test_rnn_5args : verify_program<test_rnn_5args> ...@@ -26,19 +30,23 @@ struct test_rnn_5args : verify_program<test_rnn_5args>
auto w = mm->add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
auto und = mm->add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto output = auto output = mm->add_instruction(
mm->add_instruction(migraphx::op::rnn{hidden_size, migraphx::make_op(
{migraphx::op::tanh{}, migraphx::op::tanh{}}, "rnn",
migraphx::op::rnn_direction::forward, {{"hidden_size", hidden_size},
clip}, {"actv_func",
seq, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
w, migraphx::make_op("tanh")})},
r, {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
bias, {"clip", clip}}),
und); seq,
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, output); w,
r,
bias,
und);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output);
return p; return p;
} }
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args> struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args>
...@@ -23,18 +27,22 @@ struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args> ...@@ -23,18 +27,22 @@ struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args>
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = mm->add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto output = auto output = mm->add_instruction(
mm->add_instruction(migraphx::op::rnn{hidden_size, migraphx::make_op(
{migraphx::op::tanh{}, migraphx::op::tanh{}}, "rnn",
migraphx::op::rnn_direction::bidirectional, {{"hidden_size", hidden_size},
clip}, {"actv_func",
seq, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
w, migraphx::make_op("tanh")})},
r); {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, output); {"clip", clip}}),
seq,
w,
r);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output);
return p; return p;
} }
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_rnn_bidirectional : verify_program<test_rnn_bidirectional> struct test_rnn_bidirectional : verify_program<test_rnn_bidirectional>
...@@ -28,20 +32,24 @@ struct test_rnn_bidirectional : verify_program<test_rnn_bidirectional> ...@@ -28,20 +32,24 @@ struct test_rnn_bidirectional : verify_program<test_rnn_bidirectional>
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape); auto ih = mm->add_parameter("ih", ih_shape);
auto und = mm->add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto output = auto output = mm->add_instruction(
mm->add_instruction(migraphx::op::rnn{hidden_size, migraphx::make_op(
{migraphx::op::tanh{}, migraphx::op::tanh{}}, "rnn",
migraphx::op::rnn_direction::bidirectional, {{"hidden_size", hidden_size},
clip}, {"actv_func",
seq, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
w, migraphx::make_op("tanh")})},
r, {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
bias, {"clip", clip}}),
und, seq,
ih); w,
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, output); r,
bias,
und,
ih);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output);
return p; return p;
} }
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10> struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
...@@ -23,24 +27,28 @@ struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10> ...@@ -23,24 +27,28 @@ struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = mm->add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape); auto ih = mm->add_parameter("ih", ih_shape);
auto und = mm->add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto output = auto output = mm->add_instruction(
mm->add_instruction(migraphx::op::rnn{hidden_size, migraphx::make_op(
{migraphx::op::tanh{}, migraphx::op::tanh{}}, "rnn",
migraphx::op::rnn_direction::bidirectional, {{"hidden_size", hidden_size},
clip}, {"actv_func",
seq, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
w, migraphx::make_op("tanh")})},
r, {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
bias, {"clip", clip}}),
und, seq,
ih); w,
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, output); r,
bias,
und,
ih);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output);
return p; return p;
} }
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_rnn_forward : verify_program<test_rnn_forward> struct test_rnn_forward : verify_program<test_rnn_forward>
...@@ -28,20 +32,24 @@ struct test_rnn_forward : verify_program<test_rnn_forward> ...@@ -28,20 +32,24 @@ struct test_rnn_forward : verify_program<test_rnn_forward>
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape); auto ih = mm->add_parameter("ih", ih_shape);
auto und = mm->add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto hs = auto hs = mm->add_instruction(
mm->add_instruction(migraphx::op::rnn{hidden_size, migraphx::make_op(
{migraphx::op::tanh{}, migraphx::op::tanh{}}, "rnn",
migraphx::op::rnn_direction::forward, {{"hidden_size", hidden_size},
clip}, {"actv_func",
seq, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
w, migraphx::make_op("tanh")})},
r, {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
bias, {"clip", clip}}),
und, seq,
ih); w,
auto lho = mm->add_instruction(migraphx::op::rnn_last_hs_output{}, hs); r,
bias,
und,
ih);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
mm->add_return({hs, lho}); mm->add_return({hs, lho});
return p; return p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment