Unverified Commit 8d21fdc9 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Refactor to use make_op almost everywhere (#696)

* Load op when serializing

* Formatting

* Add missing clip field

* Use make_op almost everywhere

* Formatting

* More make ops for rnns

* Get rid of spaces

* Formatting

* Remove operators headers

* Formatting

* Remove unused op headers

* Increase line threshold
parent b5633c27
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_tan : verify_program<test_tan> struct test_tan : verify_program<test_tan>
{ {
...@@ -12,7 +12,7 @@ struct test_tan : verify_program<test_tan> ...@@ -12,7 +12,7 @@ struct test_tan : verify_program<test_tan>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape s{migraphx::shape::float_type, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::op::tan{}, x); mm->add_instruction(migraphx::make_op("tan"), x);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_tanh : verify_program<test_tanh> struct test_tanh : verify_program<test_tanh>
{ {
...@@ -11,7 +11,7 @@ struct test_tanh : verify_program<test_tanh> ...@@ -11,7 +11,7 @@ struct test_tanh : verify_program<test_tanh>
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
mm->add_instruction(migraphx::op::tanh{}, x); mm->add_instruction(migraphx::make_op("tanh"), x);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_trans_abs : verify_program<test_trans_abs> struct test_trans_abs : verify_program<test_trans_abs>
{ {
...@@ -11,10 +11,10 @@ struct test_trans_abs : verify_program<test_trans_abs> ...@@ -11,10 +11,10 @@ struct test_trans_abs : verify_program<test_trans_abs>
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x); auto tx = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), x);
auto absx = mm->add_instruction(migraphx::op::abs{}, tx); auto absx = mm->add_instruction(migraphx::make_op("abs"), tx);
auto r = mm->add_instruction(migraphx::op::add{}, absx, absx); auto r = mm->add_instruction(migraphx::make_op("add"), absx, absx);
mm->add_instruction(migraphx::op::contiguous{}, r); mm->add_instruction(migraphx::make_op("contiguous"), r);
return p; return p;
} }
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_trans_ret : verify_program<test_trans_ret> struct test_trans_ret : verify_program<test_trans_ret>
{ {
...@@ -11,7 +11,7 @@ struct test_trans_ret : verify_program<test_trans_ret> ...@@ -11,7 +11,7 @@ struct test_trans_ret : verify_program<test_trans_ret>
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x); auto tx = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), x);
mm->add_return({tx}); mm->add_return({tx});
return p; return p;
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_trans_tanh : verify_program<test_trans_tanh> struct test_trans_tanh : verify_program<test_trans_tanh>
{ {
...@@ -11,10 +11,10 @@ struct test_trans_tanh : verify_program<test_trans_tanh> ...@@ -11,10 +11,10 @@ struct test_trans_tanh : verify_program<test_trans_tanh>
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x); auto tx = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), x);
auto tanhx = mm->add_instruction(migraphx::op::tanh{}, tx); auto tanhx = mm->add_instruction(migraphx::make_op("tanh"), tx);
auto r = mm->add_instruction(migraphx::op::add{}, tanhx, tanhx); auto r = mm->add_instruction(migraphx::make_op("add"), tanhx, tanhx);
mm->add_instruction(migraphx::op::contiguous{}, r); mm->add_instruction(migraphx::make_op("contiguous"), r);
return p; return p;
} }
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_trans_tanh1 : verify_program<test_trans_tanh1> struct test_trans_tanh1 : verify_program<test_trans_tanh1>
{ {
...@@ -11,9 +11,9 @@ struct test_trans_tanh1 : verify_program<test_trans_tanh1> ...@@ -11,9 +11,9 @@ struct test_trans_tanh1 : verify_program<test_trans_tanh1>
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x); auto tx = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), x);
auto tanhx = mm->add_instruction(migraphx::op::tanh{}, tx); auto tanhx = mm->add_instruction(migraphx::make_op("tanh"), tx);
auto r = mm->add_instruction(migraphx::op::add{}, tanhx, tanhx); auto r = mm->add_instruction(migraphx::make_op("add"), tanhx, tanhx);
mm->add_return({tx, r}); mm->add_return({tx, r});
return p; return p;
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_transpose : verify_program<test_transpose> struct test_transpose : verify_program<test_transpose>
{ {
...@@ -13,8 +13,8 @@ struct test_transpose : verify_program<test_transpose> ...@@ -13,8 +13,8 @@ struct test_transpose : verify_program<test_transpose>
migraphx::shape s{migraphx::shape::float_type, {4, 3, 4, 4}}; migraphx::shape s{migraphx::shape::float_type, {4, 3, 4, 4}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
std::vector<int64_t> perm = {0, 2, 3, 1}; std::vector<int64_t> perm = {0, 2, 3, 1};
auto l = mm->add_instruction(migraphx::op::transpose{perm}, x); auto l = mm->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), x);
mm->add_instruction(migraphx::op::contiguous{}, l); mm->add_instruction(migraphx::make_op("contiguous"), l);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_triadd : verify_program<test_triadd> struct test_triadd : verify_program<test_triadd>
{ {
...@@ -14,8 +14,8 @@ struct test_triadd : verify_program<test_triadd> ...@@ -14,8 +14,8 @@ struct test_triadd : verify_program<test_triadd>
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s); auto z = mm->add_parameter("z", s);
auto sum = mm->add_instruction(migraphx::op::add{}, x, y); auto sum = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_instruction(migraphx::op::add{}, sum, z); mm->add_instruction(migraphx::make_op("add"), sum, z);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_triadd2 : verify_program<test_triadd2> struct test_triadd2 : verify_program<test_triadd2>
{ {
...@@ -12,12 +12,13 @@ struct test_triadd2 : verify_program<test_triadd2> ...@@ -12,12 +12,13 @@ struct test_triadd2 : verify_program<test_triadd2>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::shape b{migraphx::shape::float_type, {3}}; migraphx::shape b{migraphx::shape::float_type, {3}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", b); auto z = mm->add_parameter("z", b);
auto zb = mm->add_instruction(migraphx::op::broadcast{1, s.lens()}, z); auto zb = mm->add_instruction(
auto sum = mm->add_instruction(migraphx::op::add{}, x, y); migraphx::make_op("broadcast", {{"axis", 1}, {"dims", s.lens()}}), z);
mm->add_instruction(migraphx::op::add{}, sum, zb); auto sum = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_instruction(migraphx::make_op("add"), sum, zb);
return p; return p;
} }
}; };
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
struct test_triadd_broadcast : verify_program<test_triadd_broadcast> struct test_triadd_broadcast : verify_program<test_triadd_broadcast>
...@@ -12,12 +13,13 @@ struct test_triadd_broadcast : verify_program<test_triadd_broadcast> ...@@ -12,12 +13,13 @@ struct test_triadd_broadcast : verify_program<test_triadd_broadcast>
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}}); auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {2, 2}}); auto y = mm->add_parameter("y", {migraphx::shape::float_type, {2, 2}});
auto z = mm->add_parameter("z", {migraphx::shape::float_type, {2, 2, 3}}); auto z = mm->add_parameter("z", {migraphx::shape::float_type, {2, 2, 3}});
auto by = mm->add_instruction(migraphx::op::broadcast{0, x->get_shape().lens()}, y); auto by = mm->add_instruction(
auto sum = mm->add_instruction(migraphx::op::add{}, x, by); migraphx::make_op("broadcast", {{"axis", 0}, {"dims", x->get_shape().lens()}}), y);
mm->add_instruction(migraphx::op::add{}, sum, z); auto sum = mm->add_instruction(migraphx::make_op("add"), x, by);
mm->add_instruction(migraphx::make_op("add"), sum, z);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_triadd_relu : verify_program<test_triadd_relu> struct test_triadd_relu : verify_program<test_triadd_relu>
{ {
...@@ -13,9 +13,9 @@ struct test_triadd_relu : verify_program<test_triadd_relu> ...@@ -13,9 +13,9 @@ struct test_triadd_relu : verify_program<test_triadd_relu>
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto sum = mm->add_instruction(migraphx::op::add{}, x, y); auto sum = mm->add_instruction(migraphx::make_op("add"), x, y);
auto triadd = mm->add_instruction(migraphx::op::add{}, sum, z); auto triadd = mm->add_instruction(migraphx::make_op("add"), sum, z);
mm->add_instruction(migraphx::op::relu{}, triadd); mm->add_instruction(migraphx::make_op("relu"), triadd);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_triadd_sigmoid : verify_program<test_triadd_sigmoid> struct test_triadd_sigmoid : verify_program<test_triadd_sigmoid>
{ {
...@@ -13,9 +13,9 @@ struct test_triadd_sigmoid : verify_program<test_triadd_sigmoid> ...@@ -13,9 +13,9 @@ struct test_triadd_sigmoid : verify_program<test_triadd_sigmoid>
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto sum = mm->add_instruction(migraphx::op::add{}, x, y); auto sum = mm->add_instruction(migraphx::make_op("add"), x, y);
auto triadd = mm->add_instruction(migraphx::op::add{}, sum, z); auto triadd = mm->add_instruction(migraphx::make_op("add"), sum, z);
mm->add_instruction(migraphx::op::sigmoid{}, triadd); mm->add_instruction(migraphx::make_op("sigmoid"), triadd);
return p; return p;
} }
}; };
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct test_triadd_tanh : verify_program<test_triadd_tanh> struct test_triadd_tanh : verify_program<test_triadd_tanh>
{ {
...@@ -13,9 +13,9 @@ struct test_triadd_tanh : verify_program<test_triadd_tanh> ...@@ -13,9 +13,9 @@ struct test_triadd_tanh : verify_program<test_triadd_tanh>
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto sum = mm->add_instruction(migraphx::op::add{}, x, y); auto sum = mm->add_instruction(migraphx::make_op("add"), x, y);
auto triadd = mm->add_instruction(migraphx::op::add{}, sum, z); auto triadd = mm->add_instruction(migraphx::make_op("add"), sum, z);
mm->add_instruction(migraphx::op::tanh{}, triadd); mm->add_instruction(migraphx::make_op("tanh"), triadd);
return p; return p;
} }
}; };
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_var_sl_gru_bidirct : verify_program<test_var_sl_gru_bidirct> struct test_var_sl_gru_bidirct : verify_program<test_var_sl_gru_bidirct>
...@@ -34,18 +38,22 @@ struct test_var_sl_gru_bidirct : verify_program<test_var_sl_gru_bidirct> ...@@ -34,18 +38,22 @@ struct test_var_sl_gru_bidirct : verify_program<test_var_sl_gru_bidirct>
std::vector<int> sl_data{2, 1, 3}; std::vector<int> sl_data{2, 1, 3};
auto sql = mm->add_literal(migraphx::literal{sl_shape, sl_data}); auto sql = mm->add_literal(migraphx::literal{sl_shape, sl_data});
auto hs = auto hs = mm->add_instruction(
mm->add_instruction(migraphx::op::gru{hidden_size, migraphx::make_op(
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, "gru",
migraphx::op::rnn_direction::bidirectional, {{"hidden_size", hidden_size},
clip}, {"actv_func",
seq, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
w, migraphx::make_op("tanh")})},
r, {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
bias, {"clip", clip}}),
sql, seq,
ih); w,
auto lho = mm->add_instruction(migraphx::op::rnn_last_hs_output{}, hs); r,
bias,
sql,
ih);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
mm->add_return({hs, lho}); mm->add_return({hs, lho});
return p; return p;
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
struct test_var_sl_gru_forward : verify_program<test_var_sl_gru_forward> struct test_var_sl_gru_forward : verify_program<test_var_sl_gru_forward>
...@@ -34,18 +38,22 @@ struct test_var_sl_gru_forward : verify_program<test_var_sl_gru_forward> ...@@ -34,18 +38,22 @@ struct test_var_sl_gru_forward : verify_program<test_var_sl_gru_forward>
std::vector<int> sl_data{3, 2, 1}; std::vector<int> sl_data{3, 2, 1};
auto sql = mm->add_literal(migraphx::literal{sl_shape, sl_data}); auto sql = mm->add_literal(migraphx::literal{sl_shape, sl_data});
auto hs = auto hs = mm->add_instruction(
mm->add_instruction(migraphx::op::gru{hidden_size, migraphx::make_op(
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, "gru",
migraphx::op::rnn_direction::forward, {{"hidden_size", hidden_size},
clip}, {"actv_func",
seq, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
w, migraphx::make_op("tanh")})},
r, {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
bias, {"clip", clip}}),
sql, seq,
ih); w,
auto lho = mm->add_instruction(migraphx::op::rnn_last_hs_output{}, hs); r,
bias,
sql,
ih);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
mm->add_return({lho, hs}); mm->add_return({lho, hs});
return p; return p;
......
...@@ -348,6 +348,9 @@ bool has_finalize(const T& x) ...@@ -348,6 +348,9 @@ bool has_finalize(const T& x)
return detail::has_finalize_op(x); return detail::has_finalize_op(x);
} }
void migraphx_to_value(value& v, const operation& op);
void migraphx_from_value(const value& v, operation& op);
#endif #endif
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment