Unverified Commit 8d21fdc9 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Refactor to use make_op almost everywhere (#696)

* Load op when serializing

* Formatting

* Add missing clip field

* Use make_op almost everywhere

* Formatting

* More make ops for rnns

* Get rid of spaces

* Formatting

* Remove operators headers

* Formatting

* Remove unused op headers

* Increase line threshold
parent b5633c27
......@@ -2,6 +2,10 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp>
struct test_lstm_reverse_3args_cell_output : verify_program<test_lstm_reverse_3args_cell_output>
......@@ -26,15 +30,19 @@ struct test_lstm_reverse_3args_cell_output : verify_program<test_lstm_reverse_3a
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto hs = mm->add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
seq,
w,
r);
mm->add_instruction(migraphx::op::rnn_last_cell_output{}, hs);
mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
return p;
}
......
......@@ -2,6 +2,10 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/operators.hpp>
struct test_lstm_reverse_last : verify_program<test_lstm_reverse_last>
......@@ -34,14 +38,18 @@ struct test_lstm_reverse_last : verify_program<test_lstm_reverse_last>
auto ih = mm->add_parameter("ih", ih_shape);
auto ic = mm->add_parameter("ic", ic_shape);
auto pph = mm->add_parameter("pph", pph_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto output = mm->add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
seq,
w,
r,
......@@ -50,7 +58,7 @@ struct test_lstm_reverse_last : verify_program<test_lstm_reverse_last>
ih,
ic,
pph);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, output);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output);
return p;
}
......
......@@ -2,6 +2,10 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp>
struct test_lstm_three_outputs : verify_program<test_lstm_three_outputs>
......@@ -26,16 +30,20 @@ struct test_lstm_three_outputs : verify_program<test_lstm_three_outputs>
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto hs = mm->add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq,
w,
r);
auto last_hs = mm->add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
auto last_cell = mm->add_instruction(migraphx::op::rnn_last_cell_output{}, hs);
auto last_hs = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
mm->add_return({hs, last_hs, last_cell});
return p;
......
......@@ -2,6 +2,10 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp>
struct test_lstm_two_outputs : verify_program<test_lstm_two_outputs>
......@@ -26,15 +30,19 @@ struct test_lstm_two_outputs : verify_program<test_lstm_two_outputs>
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto hs = mm->add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq,
w,
r);
auto last_hs = mm->add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
auto last_hs = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
mm->add_return({hs, last_hs});
return p;
......
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_mul : verify_program<test_mul>
{
......@@ -13,7 +13,7 @@ struct test_mul : verify_program<test_mul>
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(migraphx::op::mul{}, x, y);
mm->add_instruction(migraphx::make_op("mul"), x, y);
return p;
}
};
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_mul_add : verify_program<test_mul_add>
{
......@@ -12,13 +12,15 @@ struct test_mul_add : verify_program<test_mul_add>
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::shape bs{migraphx::shape::float_type, {3}};
auto x = mm->add_parameter("x", s);
auto a = mm->add_parameter("a", bs);
auto b = mm->add_parameter("b", bs);
auto ab = mm->add_instruction(migraphx::op::broadcast{1, s.lens()}, a);
auto bb = mm->add_instruction(migraphx::op::broadcast{1, s.lens()}, b);
auto mul = mm->add_instruction(migraphx::op::mul{}, x, ab);
mm->add_instruction(migraphx::op::add{}, mul, bb);
auto x = mm->add_parameter("x", s);
auto a = mm->add_parameter("a", bs);
auto b = mm->add_parameter("b", bs);
auto ab = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", s.lens()}}), a);
auto bb = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", s.lens()}}), b);
auto mul = mm->add_instruction(migraphx::make_op("mul"), x, ab);
mm->add_instruction(migraphx::make_op("add"), mul, bb);
return p;
}
};
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_neg : verify_program<test_neg>
{
......@@ -13,7 +13,7 @@ struct test_neg : verify_program<test_neg>
migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}};
auto input = mm->add_parameter("x", s);
mm->add_instruction(migraphx::op::neg{}, input);
mm->add_instruction(migraphx::make_op("neg"), input);
return p;
};
};
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_pad : verify_program<test_pad>
{
......@@ -16,10 +16,10 @@ struct test_pad : verify_program<test_pad>
std::vector<int64_t> pads2 = {1, 1, 1, 1, 0, 0, 0, 0};
std::vector<int64_t> pads3 = {1, 0, 1, 0, 1, 0, 2, 0};
auto l0 = mm->add_parameter("x", s0);
mm->add_instruction(migraphx::op::pad{pads0}, l0);
mm->add_instruction(migraphx::op::pad{pads1}, l0);
mm->add_instruction(migraphx::op::pad{pads2}, l0);
mm->add_instruction(migraphx::op::pad{pads3}, l0);
mm->add_instruction(migraphx::make_op("pad", {{"pads", pads0}}), l0);
mm->add_instruction(migraphx::make_op("pad", {{"pads", pads1}}), l0);
mm->add_instruction(migraphx::make_op("pad", {{"pads", pads2}}), l0);
mm->add_instruction(migraphx::make_op("pad", {{"pads", pads3}}), l0);
return p;
}
};
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_pad_transposed : verify_program<test_pad_transposed>
{
......@@ -12,8 +12,8 @@ struct test_pad_transposed : verify_program<test_pad_transposed>
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {1, 224, 224, 3}};
auto x = mm->add_parameter("x", s);
auto t = mm->add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, x);
mm->add_instruction(migraphx::op::pad{{0, 0, 2, 2, 0, 0, 3, 3}}, t);
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 3, 1, 2}}}), x);
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 2, 2, 0, 0, 3, 3}}}), t);
return p;
}
};
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_pow : verify_program<test_pow>
{
......@@ -14,7 +14,7 @@ struct test_pow : verify_program<test_pow>
std::vector<float> vec_e(s.elements(), 2.0f);
auto b = mm->add_parameter("x", s);
auto e = mm->add_literal(migraphx::literal(s, vec_e));
mm->add_instruction(migraphx::op::pow{}, b, e);
mm->add_instruction(migraphx::make_op("pow"), b, e);
return p;
}
};
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_prelu_brcst : verify_program<test_prelu_brcst>
{
......@@ -13,7 +13,7 @@ struct test_prelu_brcst : verify_program<test_prelu_brcst>
migraphx::shape s{migraphx::shape::float_type, {6}};
auto x = mm->add_parameter("x", s);
auto slp = mm->add_parameter("slp", s);
auto r = mm->add_instruction(migraphx::op::prelu{}, x, slp);
auto r = mm->add_instruction(migraphx::make_op("prelu"), x, slp);
mm->add_return({r});
return p;
......
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_recip : verify_program<test_recip>
{
......@@ -12,7 +12,7 @@ struct test_recip : verify_program<test_recip>
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {3}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::op::recip{}, x);
mm->add_instruction(migraphx::make_op("recip"), x);
return p;
}
};
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_relu_lrn : verify_program<test_relu_lrn>
{
......@@ -11,8 +11,11 @@ struct test_relu_lrn : verify_program<test_relu_lrn>
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 5, 2, 2}});
auto y = mm->add_instruction(migraphx::op::relu{}, x);
mm->add_instruction(migraphx::op::lrn{0.0001, 0.75, 1.0, 5}, y);
auto y = mm->add_instruction(migraphx::make_op("relu"), x);
mm->add_instruction(
migraphx::make_op("lrn",
{{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}),
y);
return p;
}
};
......@@ -2,6 +2,10 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp>
struct test_rnn_3args : verify_program<test_rnn_3args>
......@@ -25,13 +29,18 @@ struct test_rnn_3args : verify_program<test_rnn_3args>
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
mm->add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r);
mm->add_instruction(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
seq,
w,
r);
return p;
}
......
......@@ -2,6 +2,10 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp>
struct test_rnn_4args : verify_program<test_rnn_4args>
......@@ -27,14 +31,19 @@ struct test_rnn_4args : verify_program<test_rnn_4args>
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape);
mm->add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias);
mm->add_instruction(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
seq,
w,
r,
bias);
return p;
}
......
......@@ -2,6 +2,10 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/operators.hpp>
struct test_rnn_5args : verify_program<test_rnn_5args>
......@@ -26,19 +30,23 @@ struct test_rnn_5args : verify_program<test_rnn_5args>
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto output =
mm->add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, output);
auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto output = mm->add_instruction(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq,
w,
r,
bias,
und);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output);
return p;
}
......
......@@ -2,6 +2,10 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp>
struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args>
......@@ -23,18 +27,22 @@ struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args>
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto output =
mm->add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, output);
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto output = mm->add_instruction(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
seq,
w,
r);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output);
return p;
}
......
......@@ -2,6 +2,10 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/operators.hpp>
struct test_rnn_bidirectional : verify_program<test_rnn_bidirectional>
......@@ -28,20 +32,24 @@ struct test_rnn_bidirectional : verify_program<test_rnn_bidirectional>
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto output =
mm->add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, output);
auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto output = mm->add_instruction(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
seq,
w,
r,
bias,
und,
ih);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output);
return p;
}
......
......@@ -2,6 +2,10 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/operators.hpp>
struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
......@@ -23,24 +27,28 @@ struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto output =
mm->add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, output);
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape);
auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto output = mm->add_instruction(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
seq,
w,
r,
bias,
und,
ih);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output);
return p;
}
......
......@@ -2,6 +2,10 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/operators.hpp>
struct test_rnn_forward : verify_program<test_rnn_forward>
......@@ -28,20 +32,24 @@ struct test_rnn_forward : verify_program<test_rnn_forward>
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto hs =
mm->add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
auto lho = mm->add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto hs = mm->add_instruction(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq,
w,
r,
bias,
und,
ih);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
mm->add_return({hs, lho});
return p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment