Unverified Commit 8d21fdc9 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Refactor to use make_op almost everywhere (#696)

* Load op when serializing

* Formatting

* Add missing clip field

* Use make_op almost everywhere

* Formatting

* More make ops for rnns

* Get rid of spaces

* Formatting

* Remove operators headers

* Formatting

* Remove unused op headers

* Increase line threshold
parent b5633c27
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_tan : verify_program<test_tan>
{
......@@ -12,7 +12,7 @@ struct test_tan : verify_program<test_tan>
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::op::tan{}, x);
mm->add_instruction(migraphx::make_op("tan"), x);
return p;
}
};
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_tanh : verify_program<test_tanh>
{
......@@ -11,7 +11,7 @@ struct test_tanh : verify_program<test_tanh>
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
mm->add_instruction(migraphx::op::tanh{}, x);
mm->add_instruction(migraphx::make_op("tanh"), x);
return p;
}
};
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_trans_abs : verify_program<test_trans_abs>
{
......@@ -11,10 +11,10 @@ struct test_trans_abs : verify_program<test_trans_abs>
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto absx = mm->add_instruction(migraphx::op::abs{}, tx);
auto r = mm->add_instruction(migraphx::op::add{}, absx, absx);
mm->add_instruction(migraphx::op::contiguous{}, r);
auto tx = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), x);
auto absx = mm->add_instruction(migraphx::make_op("abs"), tx);
auto r = mm->add_instruction(migraphx::make_op("add"), absx, absx);
mm->add_instruction(migraphx::make_op("contiguous"), r);
return p;
}
......
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_trans_ret : verify_program<test_trans_ret>
{
......@@ -11,7 +11,7 @@ struct test_trans_ret : verify_program<test_trans_ret>
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto tx = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), x);
mm->add_return({tx});
return p;
......
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_trans_tanh : verify_program<test_trans_tanh>
{
......@@ -11,10 +11,10 @@ struct test_trans_tanh : verify_program<test_trans_tanh>
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto tanhx = mm->add_instruction(migraphx::op::tanh{}, tx);
auto r = mm->add_instruction(migraphx::op::add{}, tanhx, tanhx);
mm->add_instruction(migraphx::op::contiguous{}, r);
auto tx = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), x);
auto tanhx = mm->add_instruction(migraphx::make_op("tanh"), tx);
auto r = mm->add_instruction(migraphx::make_op("add"), tanhx, tanhx);
mm->add_instruction(migraphx::make_op("contiguous"), r);
return p;
}
......
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_trans_tanh1 : verify_program<test_trans_tanh1>
{
......@@ -11,9 +11,9 @@ struct test_trans_tanh1 : verify_program<test_trans_tanh1>
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto tanhx = mm->add_instruction(migraphx::op::tanh{}, tx);
auto r = mm->add_instruction(migraphx::op::add{}, tanhx, tanhx);
auto tx = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), x);
auto tanhx = mm->add_instruction(migraphx::make_op("tanh"), tx);
auto r = mm->add_instruction(migraphx::make_op("add"), tanhx, tanhx);
mm->add_return({tx, r});
return p;
......
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_transpose : verify_program<test_transpose>
{
......@@ -13,8 +13,8 @@ struct test_transpose : verify_program<test_transpose>
migraphx::shape s{migraphx::shape::float_type, {4, 3, 4, 4}};
auto x = mm->add_parameter("x", s);
std::vector<int64_t> perm = {0, 2, 3, 1};
auto l = mm->add_instruction(migraphx::op::transpose{perm}, x);
mm->add_instruction(migraphx::op::contiguous{}, l);
auto l = mm->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), x);
mm->add_instruction(migraphx::make_op("contiguous"), l);
return p;
}
};
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_triadd : verify_program<test_triadd>
{
......@@ -14,8 +14,8 @@ struct test_triadd : verify_program<test_triadd>
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto sum = mm->add_instruction(migraphx::op::add{}, x, y);
mm->add_instruction(migraphx::op::add{}, sum, z);
auto sum = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_instruction(migraphx::make_op("add"), sum, z);
return p;
}
};
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_triadd2 : verify_program<test_triadd2>
{
......@@ -12,12 +12,13 @@ struct test_triadd2 : verify_program<test_triadd2>
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::shape b{migraphx::shape::float_type, {3}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", b);
auto zb = mm->add_instruction(migraphx::op::broadcast{1, s.lens()}, z);
auto sum = mm->add_instruction(migraphx::op::add{}, x, y);
mm->add_instruction(migraphx::op::add{}, sum, zb);
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", b);
auto zb = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", s.lens()}}), z);
auto sum = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_instruction(migraphx::make_op("add"), sum, zb);
return p;
}
};
......@@ -2,7 +2,8 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
struct test_triadd_broadcast : verify_program<test_triadd_broadcast>
......@@ -12,12 +13,13 @@ struct test_triadd_broadcast : verify_program<test_triadd_broadcast>
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {2, 2}});
auto z = mm->add_parameter("z", {migraphx::shape::float_type, {2, 2, 3}});
auto by = mm->add_instruction(migraphx::op::broadcast{0, x->get_shape().lens()}, y);
auto sum = mm->add_instruction(migraphx::op::add{}, x, by);
mm->add_instruction(migraphx::op::add{}, sum, z);
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {2, 2}});
auto z = mm->add_parameter("z", {migraphx::shape::float_type, {2, 2, 3}});
auto by = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", x->get_shape().lens()}}), y);
auto sum = mm->add_instruction(migraphx::make_op("add"), x, by);
mm->add_instruction(migraphx::make_op("add"), sum, z);
return p;
}
};
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_triadd_relu : verify_program<test_triadd_relu>
{
......@@ -13,9 +13,9 @@ struct test_triadd_relu : verify_program<test_triadd_relu>
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto sum = mm->add_instruction(migraphx::op::add{}, x, y);
auto triadd = mm->add_instruction(migraphx::op::add{}, sum, z);
mm->add_instruction(migraphx::op::relu{}, triadd);
auto sum = mm->add_instruction(migraphx::make_op("add"), x, y);
auto triadd = mm->add_instruction(migraphx::make_op("add"), sum, z);
mm->add_instruction(migraphx::make_op("relu"), triadd);
return p;
}
};
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_triadd_sigmoid : verify_program<test_triadd_sigmoid>
{
......@@ -13,9 +13,9 @@ struct test_triadd_sigmoid : verify_program<test_triadd_sigmoid>
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto sum = mm->add_instruction(migraphx::op::add{}, x, y);
auto triadd = mm->add_instruction(migraphx::op::add{}, sum, z);
mm->add_instruction(migraphx::op::sigmoid{}, triadd);
auto sum = mm->add_instruction(migraphx::make_op("add"), x, y);
auto triadd = mm->add_instruction(migraphx::make_op("add"), sum, z);
mm->add_instruction(migraphx::make_op("sigmoid"), triadd);
return p;
}
};
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct test_triadd_tanh : verify_program<test_triadd_tanh>
{
......@@ -13,9 +13,9 @@ struct test_triadd_tanh : verify_program<test_triadd_tanh>
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto sum = mm->add_instruction(migraphx::op::add{}, x, y);
auto triadd = mm->add_instruction(migraphx::op::add{}, sum, z);
mm->add_instruction(migraphx::op::tanh{}, triadd);
auto sum = mm->add_instruction(migraphx::make_op("add"), x, y);
auto triadd = mm->add_instruction(migraphx::make_op("add"), sum, z);
mm->add_instruction(migraphx::make_op("tanh"), triadd);
return p;
}
};
......@@ -2,6 +2,10 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp>
struct test_var_sl_gru_bidirct : verify_program<test_var_sl_gru_bidirct>
......@@ -34,18 +38,22 @@ struct test_var_sl_gru_bidirct : verify_program<test_var_sl_gru_bidirct>
std::vector<int> sl_data{2, 1, 3};
auto sql = mm->add_literal(migraphx::literal{sl_shape, sl_data});
auto hs =
mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
sql,
ih);
auto lho = mm->add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
auto hs = mm->add_instruction(
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
seq,
w,
r,
bias,
sql,
ih);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
mm->add_return({hs, lho});
return p;
......
......@@ -2,6 +2,10 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operators.hpp>
struct test_var_sl_gru_forward : verify_program<test_var_sl_gru_forward>
......@@ -34,18 +38,22 @@ struct test_var_sl_gru_forward : verify_program<test_var_sl_gru_forward>
std::vector<int> sl_data{3, 2, 1};
auto sql = mm->add_literal(migraphx::literal{sl_shape, sl_data});
auto hs =
mm->add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
sql,
ih);
auto lho = mm->add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
auto hs = mm->add_instruction(
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq,
w,
r,
bias,
sql,
ih);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
mm->add_return({lho, hs});
return p;
......
......@@ -348,6 +348,9 @@ bool has_finalize(const T& x)
return detail::has_finalize_op(x);
}
void migraphx_to_value(value& v, const operation& op);
void migraphx_from_value(const value& v, operation& op);
#endif
} // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment