Unverified Commit 8d21fdc9 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Refactor to use make_op almost everywhere (#696)

* Load op when serializing

* Formatting

* Add missing clip field

* Use make_op almost everywhere

* Formatting

* More make ops for rnns

* Get rid of spaces

* Formatting

* Remove operators headers

* Formatting

* Remove unused op headers

* Increase line threshold
parent b5633c27
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_less_brcst : verify_program<test_less_brcst> struct test_less_brcst : verify_program<test_less_brcst>
{ {
...@@ -14,8 +14,9 @@ struct test_less_brcst : verify_program<test_less_brcst> ...@@ -14,8 +14,9 @@ struct test_less_brcst : verify_program<test_less_brcst>
auto l0 = mm->add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; migraphx::shape s1{migraphx::shape::float_type, {3, 1}};
auto l1 = mm->add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto bl1 = mm->add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1); auto bl1 = mm->add_instruction(
auto r = mm->add_instruction(migraphx::op::less{}, l0, bl1); migraphx::make_op("multibroadcast", {{"output_lens", s0.lens()}}), l1);
auto r = mm->add_instruction(migraphx::make_op("less"), l0, bl1);
mm->add_return({r}); mm->add_return({r});
return p; return p;
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_literals : verify_program<test_literals> struct test_literals : verify_program<test_literals>
{ {
...@@ -14,8 +14,8 @@ struct test_literals : verify_program<test_literals> ...@@ -14,8 +14,8 @@ struct test_literals : verify_program<test_literals>
generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}})); generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}));
auto weights = mm->add_literal( auto weights = mm->add_literal(
generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}})); generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}));
auto conv = mm->add_instruction(migraphx::op::convolution{}, input, weights); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_instruction(migraphx::op::relu{}, conv); mm->add_instruction(migraphx::make_op("relu"), conv);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_log : verify_program<test_log> struct test_log : verify_program<test_log>
{ {
...@@ -11,8 +11,8 @@ struct test_log : verify_program<test_log> ...@@ -11,8 +11,8 @@ struct test_log : verify_program<test_log>
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {6}}; migraphx::shape s{migraphx::shape::float_type, {6}};
auto x = mm->add_instruction(migraphx::op::abs{}, mm->add_parameter("x", s)); auto x = mm->add_instruction(migraphx::make_op("abs"), mm->add_parameter("x", s));
mm->add_instruction(migraphx::op::log{}, x); mm->add_instruction(migraphx::make_op("log"), x);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
template <int Axis, migraphx::shape::type_t T> template <int Axis, migraphx::shape::type_t T>
struct test_logsoftmax : verify_program<test_logsoftmax<Axis, T>> struct test_logsoftmax : verify_program<test_logsoftmax<Axis, T>>
...@@ -13,7 +13,7 @@ struct test_logsoftmax : verify_program<test_logsoftmax<Axis, T>> ...@@ -13,7 +13,7 @@ struct test_logsoftmax : verify_program<test_logsoftmax<Axis, T>>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{T, {10, 4, 2080, 6}}; migraphx::shape s{T, {10, 4, 2080, 6}};
auto param = mm->add_parameter("0", s); auto param = mm->add_parameter("0", s);
mm->add_instruction(migraphx::op::logsoftmax{Axis}, param); mm->add_instruction(migraphx::make_op("logsoftmax", {{"axis", Axis}}), param);
return p; return p;
} }
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_bidirct_3args : verify_program<test_lstm_bidirct_3args> struct test_lstm_bidirct_3args : verify_program<test_lstm_bidirct_3args>
...@@ -25,13 +29,18 @@ struct test_lstm_bidirct_3args : verify_program<test_lstm_bidirct_3args> ...@@ -25,13 +29,18 @@ struct test_lstm_bidirct_3args : verify_program<test_lstm_bidirct_3args>
auto seq = mm->add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
mm->add_instruction(migraphx::op::lstm{hidden_size, mm->add_instruction(
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::make_op(
migraphx::op::rnn_direction::bidirectional, "lstm",
clip}, {{"hidden_size", hidden_size},
seq, {"actv_func",
w, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
r); migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
seq,
w,
r);
return p; return p;
} }
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_bidirct_3args_und : verify_program<test_lstm_bidirct_3args_und> struct test_lstm_bidirct_3args_und : verify_program<test_lstm_bidirct_3args_und>
...@@ -25,12 +29,17 @@ struct test_lstm_bidirct_3args_und : verify_program<test_lstm_bidirct_3args_und> ...@@ -25,12 +29,17 @@ struct test_lstm_bidirct_3args_und : verify_program<test_lstm_bidirct_3args_und>
auto seq = mm->add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto und = mm->add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction( mm->add_instruction(
migraphx::op::gru{hidden_size, migraphx::make_op(
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, "gru",
migraphx::op::rnn_direction::bidirectional, {{"hidden_size", hidden_size},
clip}, {"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
seq, seq,
w, w,
r, r,
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_bidirct_default_actv : verify_program<test_lstm_bidirct_default_actv> struct test_lstm_bidirct_default_actv : verify_program<test_lstm_bidirct_default_actv>
...@@ -26,7 +30,12 @@ struct test_lstm_bidirct_default_actv : verify_program<test_lstm_bidirct_default ...@@ -26,7 +30,12 @@ struct test_lstm_bidirct_default_actv : verify_program<test_lstm_bidirct_default
auto w = mm->add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
mm->add_instruction( mm->add_instruction(
migraphx::op::lstm{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip}, migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func", {}},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
seq, seq,
w, w,
r); r);
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_default_actv1> struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_default_actv1>
...@@ -34,16 +38,21 @@ struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_defaul ...@@ -34,16 +38,21 @@ struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_defaul
std::vector<int> sl_data(batch_size, 2); std::vector<int> sl_data(batch_size, 2);
auto sql = mm->add_literal(migraphx::literal{sl_shape, sl_data}); auto sql = mm->add_literal(migraphx::literal{sl_shape, sl_data});
mm->add_instruction(migraphx::op::lstm{hidden_size, mm->add_instruction(
{migraphx::op::sigmoid{}}, migraphx::make_op(
migraphx::op::rnn_direction::bidirectional, "lstm",
clip}, {{"hidden_size", hidden_size},
seq, {"actv_func",
w, migraphx::to_value(
r, std::vector<migraphx::operation>{migraphx::make_op("sigmoid")})},
bias, {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
sql, {"clip", clip}}),
ih); seq,
w,
r,
bias,
sql,
ih);
return p; return p;
} }
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_default_actv2> struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_default_actv2>
...@@ -30,18 +34,23 @@ struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_defaul ...@@ -30,18 +34,23 @@ struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_defaul
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape); auto ih = mm->add_parameter("ih", ih_shape);
auto und = mm->add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(migraphx::op::lstm{hidden_size, mm->add_instruction(
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, migraphx::make_op(
migraphx::op::rnn_direction::bidirectional, "lstm",
clip}, {{"hidden_size", hidden_size},
seq, {"actv_func",
w, migraphx::to_value(std::vector<migraphx::operation>{
r, migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})},
bias, {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
und, {"clip", clip}}),
ih); seq,
w,
r,
bias,
und,
ih);
return p; return p;
} }
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_bidirct_hs : verify_program<test_lstm_bidirct_hs> struct test_lstm_bidirct_hs : verify_program<test_lstm_bidirct_hs>
...@@ -34,16 +38,21 @@ struct test_lstm_bidirct_hs : verify_program<test_lstm_bidirct_hs> ...@@ -34,16 +38,21 @@ struct test_lstm_bidirct_hs : verify_program<test_lstm_bidirct_hs>
std::vector<int> sl_data{3, 2}; std::vector<int> sl_data{3, 2};
auto sql = mm->add_literal(migraphx::literal{migraphx::literal{sl_shape, sl_data}}); auto sql = mm->add_literal(migraphx::literal{migraphx::literal{sl_shape, sl_data}});
mm->add_instruction(migraphx::op::lstm{hidden_size, mm->add_instruction(
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::make_op(
migraphx::op::rnn_direction::bidirectional, "lstm",
clip}, {{"hidden_size", hidden_size},
seq, {"actv_func",
w, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
r, migraphx::make_op("tanh")})},
bias, {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
sql, {"clip", clip}}),
ih); seq,
w,
r,
bias,
sql,
ih);
return p; return p;
} }
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_bidirct_last : verify_program<test_lstm_bidirct_last> struct test_lstm_bidirct_last : verify_program<test_lstm_bidirct_last>
...@@ -34,14 +38,18 @@ struct test_lstm_bidirct_last : verify_program<test_lstm_bidirct_last> ...@@ -34,14 +38,18 @@ struct test_lstm_bidirct_last : verify_program<test_lstm_bidirct_last>
auto ih = mm->add_parameter("ih", ih_shape); auto ih = mm->add_parameter("ih", ih_shape);
auto ic = mm->add_parameter("ic", ic_shape); auto ic = mm->add_parameter("ic", ic_shape);
auto pph = mm->add_parameter("pph", pph_shape); auto pph = mm->add_parameter("pph", pph_shape);
auto und = mm->add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto output = mm->add_instruction( auto output = mm->add_instruction(
migraphx::op::lstm{ migraphx::make_op(
hidden_size, "lstm",
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {{"hidden_size", hidden_size},
migraphx::op::rnn_direction::bidirectional, {"actv_func",
clip}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
seq, seq,
w, w,
r, r,
...@@ -50,7 +58,7 @@ struct test_lstm_bidirct_last : verify_program<test_lstm_bidirct_last> ...@@ -50,7 +58,7 @@ struct test_lstm_bidirct_last : verify_program<test_lstm_bidirct_last>
ih, ih,
ic, ic,
pph); pph);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, output); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output);
return p; return p;
} }
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_bidirct_seq1 : verify_program<test_lstm_bidirct_seq1> struct test_lstm_bidirct_seq1 : verify_program<test_lstm_bidirct_seq1>
...@@ -25,13 +29,18 @@ struct test_lstm_bidirct_seq1 : verify_program<test_lstm_bidirct_seq1> ...@@ -25,13 +29,18 @@ struct test_lstm_bidirct_seq1 : verify_program<test_lstm_bidirct_seq1>
auto seq = mm->add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
mm->add_instruction(migraphx::op::lstm{hidden_size, mm->add_instruction(
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::make_op(
migraphx::op::rnn_direction::bidirectional, "lstm",
clip}, {{"hidden_size", hidden_size},
seq, {"actv_func",
w, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
r); migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
seq,
w,
r);
return p; return p;
} }
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_forward_3args : verify_program<test_lstm_forward_3args> struct test_lstm_forward_3args : verify_program<test_lstm_forward_3args>
...@@ -26,11 +30,15 @@ struct test_lstm_forward_3args : verify_program<test_lstm_forward_3args> ...@@ -26,11 +30,15 @@ struct test_lstm_forward_3args : verify_program<test_lstm_forward_3args>
auto w = mm->add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
mm->add_instruction( mm->add_instruction(
migraphx::op::lstm{ migraphx::make_op(
hidden_size, "lstm",
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {{"hidden_size", hidden_size},
migraphx::op::rnn_direction::forward, {"actv_func",
clip}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq, seq,
w, w,
r); r);
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_forward_3args_und : verify_program<test_lstm_forward_3args_und> struct test_lstm_forward_3args_und : verify_program<test_lstm_forward_3args_und>
...@@ -25,13 +29,17 @@ struct test_lstm_forward_3args_und : verify_program<test_lstm_forward_3args_und> ...@@ -25,13 +29,17 @@ struct test_lstm_forward_3args_und : verify_program<test_lstm_forward_3args_und>
auto seq = mm->add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto und = mm->add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction( mm->add_instruction(
migraphx::op::lstm{ migraphx::make_op(
hidden_size, "lstm",
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {{"hidden_size", hidden_size},
migraphx::op::rnn_direction::forward, {"actv_func",
clip}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq, seq,
w, w,
r, r,
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_forward_default_actv : verify_program<test_lstm_forward_default_actv> struct test_lstm_forward_default_actv : verify_program<test_lstm_forward_default_actv>
...@@ -26,7 +30,12 @@ struct test_lstm_forward_default_actv : verify_program<test_lstm_forward_default ...@@ -26,7 +30,12 @@ struct test_lstm_forward_default_actv : verify_program<test_lstm_forward_default
auto w = mm->add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
mm->add_instruction( mm->add_instruction(
migraphx::op::lstm{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func", {}},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq, seq,
w, w,
r); r);
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_forward_default_actv1 : verify_program<test_lstm_forward_default_actv1> struct test_lstm_forward_default_actv1 : verify_program<test_lstm_forward_default_actv1>
...@@ -30,11 +34,17 @@ struct test_lstm_forward_default_actv1 : verify_program<test_lstm_forward_defaul ...@@ -30,11 +34,17 @@ struct test_lstm_forward_default_actv1 : verify_program<test_lstm_forward_defaul
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape); auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape); auto ih = mm->add_parameter("ih", ih_shape);
auto und = mm->add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction( mm->add_instruction(
migraphx::op::lstm{ migraphx::make_op(
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip}, "lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(
std::vector<migraphx::operation>{migraphx::make_op("sigmoid")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq, seq,
w, w,
r, r,
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_forward_hs : verify_program<test_lstm_forward_hs> struct test_lstm_forward_hs : verify_program<test_lstm_forward_hs>
...@@ -34,14 +38,18 @@ struct test_lstm_forward_hs : verify_program<test_lstm_forward_hs> ...@@ -34,14 +38,18 @@ struct test_lstm_forward_hs : verify_program<test_lstm_forward_hs>
auto ih = mm->add_parameter("ih", ih_shape); auto ih = mm->add_parameter("ih", ih_shape);
auto ic = mm->add_parameter("ic", ic_shape); auto ic = mm->add_parameter("ic", ic_shape);
auto pph = mm->add_parameter("pph", pph_shape); auto pph = mm->add_parameter("pph", pph_shape);
auto und = mm->add_instruction(migraphx::op::undefined{}); auto und = mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction( mm->add_instruction(
migraphx::op::lstm{ migraphx::make_op(
hidden_size, "lstm",
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {{"hidden_size", hidden_size},
migraphx::op::rnn_direction::forward, {"actv_func",
clip}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq, seq,
w, w,
r, r,
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_forward_last : verify_program<test_lstm_forward_last> struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
...@@ -38,11 +42,15 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last> ...@@ -38,11 +42,15 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
auto pph = mm->add_parameter("pph", pph_shape); auto pph = mm->add_parameter("pph", pph_shape);
auto output = mm->add_instruction( auto output = mm->add_instruction(
migraphx::op::lstm{ migraphx::make_op(
hidden_size, "lstm",
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {{"hidden_size", hidden_size},
migraphx::op::rnn_direction::forward, {"actv_func",
clip}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq, seq,
w, w,
r, r,
...@@ -51,7 +59,7 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last> ...@@ -51,7 +59,7 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
ih, ih,
ic, ic,
pph); pph);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, output, len); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output, len);
return p; return p;
} }
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_forward_seq1 : verify_program<test_lstm_forward_seq1> struct test_lstm_forward_seq1 : verify_program<test_lstm_forward_seq1>
...@@ -26,11 +30,15 @@ struct test_lstm_forward_seq1 : verify_program<test_lstm_forward_seq1> ...@@ -26,11 +30,15 @@ struct test_lstm_forward_seq1 : verify_program<test_lstm_forward_seq1>
auto w = mm->add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
mm->add_instruction( mm->add_instruction(
migraphx::op::lstm{ migraphx::make_op(
hidden_size, "lstm",
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {{"hidden_size", hidden_size},
migraphx::op::rnn_direction::forward, {"actv_func",
clip}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq, seq,
w, w,
r); r);
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_lstm_reverse_3args : verify_program<test_lstm_reverse_3args> struct test_lstm_reverse_3args : verify_program<test_lstm_reverse_3args>
...@@ -26,11 +30,15 @@ struct test_lstm_reverse_3args : verify_program<test_lstm_reverse_3args> ...@@ -26,11 +30,15 @@ struct test_lstm_reverse_3args : verify_program<test_lstm_reverse_3args>
auto w = mm->add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape); auto r = mm->add_parameter("r", r_shape);
mm->add_instruction( mm->add_instruction(
migraphx::op::lstm{ migraphx::make_op(
hidden_size, "lstm",
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {{"hidden_size", hidden_size},
migraphx::op::rnn_direction::reverse, {"actv_func",
clip}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
seq, seq,
w, w,
r); r);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment