Unverified Commit 8d21fdc9 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Refactor to use make_op almost everywhere (#696)

* Load op when serializing

* Formatting

* Add missing clip field

* Use make_op almost everywhere

* Formatting

* More make ops for rnns

* Get rid of spaces

* Formatting

* Remove operators headers

* Formatting

* Remove unused op headers

* Increase line threshold
parent b5633c27
#include <migraphx/auto_contiguous.hpp> #include <migraphx/auto_contiguous.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p) void run_pass(migraphx::program& p)
...@@ -46,7 +46,7 @@ TEST_CASE(after_literal_transpose) ...@@ -46,7 +46,7 @@ TEST_CASE(after_literal_transpose)
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
EXPECT(p.get_output_shapes().back().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed()); EXPECT(not p.get_output_shapes().back().transposed());
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
mm->add_instruction(pass_op{}, t); mm->add_instruction(pass_op{}, t);
EXPECT(not p.get_output_shapes().back().standard()); EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed()); EXPECT(p.get_output_shapes().back().transposed());
...@@ -64,7 +64,8 @@ TEST_CASE(after_literal_broadcast) ...@@ -64,7 +64,8 @@ TEST_CASE(after_literal_broadcast)
auto l2 = mm->add_literal(get_2()); auto l2 = mm->add_literal(get_2());
EXPECT(p.get_output_shapes().back().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().broadcasted()); EXPECT(not p.get_output_shapes().back().broadcasted());
auto b = mm->add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2); auto b = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", l1->get_shape().lens()}}), l2);
mm->add_instruction(pass_op{}, b); mm->add_instruction(pass_op{}, b);
EXPECT(not p.get_output_shapes().back().standard()); EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().broadcasted()); EXPECT(p.get_output_shapes().back().broadcasted());
...@@ -81,7 +82,7 @@ TEST_CASE(after_param_transpose) ...@@ -81,7 +82,7 @@ TEST_CASE(after_param_transpose)
auto l = mm->add_parameter("2x2", {migraphx::shape::float_type, {2, 2}}); auto l = mm->add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
EXPECT(p.get_output_shapes().back().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed()); EXPECT(not p.get_output_shapes().back().transposed());
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
mm->add_instruction(pass_op{}, t); mm->add_instruction(pass_op{}, t);
EXPECT(not p.get_output_shapes().back().standard()); EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed()); EXPECT(p.get_output_shapes().back().transposed());
...@@ -99,7 +100,8 @@ TEST_CASE(after_param_broadcast) ...@@ -99,7 +100,8 @@ TEST_CASE(after_param_broadcast)
auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {2}}); auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {2}});
EXPECT(p.get_output_shapes().back().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().broadcasted()); EXPECT(not p.get_output_shapes().back().broadcasted());
auto b = mm->add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2); auto b = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", l1->get_shape().lens()}}), l2);
mm->add_instruction(pass_op{}, b); mm->add_instruction(pass_op{}, b);
EXPECT(not p.get_output_shapes().back().standard()); EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().broadcasted()); EXPECT(p.get_output_shapes().back().broadcasted());
......
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <migraphx/op/undefined.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/identity.hpp>
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p) void run_pass(migraphx::program& p)
...@@ -116,7 +115,7 @@ TEST_CASE(undefined_test) ...@@ -116,7 +115,7 @@ TEST_CASE(undefined_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto undef = mm->add_instruction(migraphx::op::undefined{}); auto undef = mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
...@@ -133,8 +132,8 @@ TEST_CASE(duplicate_args1) ...@@ -133,8 +132,8 @@ TEST_CASE(duplicate_args1)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_literal(0); auto l0 = mm->add_literal(0);
auto l3 = mm->add_literal(3); auto l3 = mm->add_literal(3);
mm->add_instruction(migraphx::op::add{}, l3, l3); mm->add_instruction(migraphx::make_op("add"), l3, l3);
mm->add_instruction(migraphx::op::identity{}, l0); mm->add_instruction(migraphx::make_op("identity"), l0);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) != count); EXPECT(std::distance(p.begin(), p.end()) != count);
...@@ -149,9 +148,9 @@ TEST_CASE(duplicate_args2) ...@@ -149,9 +148,9 @@ TEST_CASE(duplicate_args2)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_literal(0); auto l0 = mm->add_literal(0);
auto l3 = mm->add_literal(3); auto l3 = mm->add_literal(3);
auto sum1 = mm->add_instruction(migraphx::op::add{}, l0, l3); auto sum1 = mm->add_instruction(migraphx::make_op("add"), l0, l3);
mm->add_instruction(migraphx::op::add{}, sum1, l3); mm->add_instruction(migraphx::make_op("add"), sum1, l3);
mm->add_instruction(migraphx::op::identity{}, l0); mm->add_instruction(migraphx::make_op("identity"), l0);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) != count); EXPECT(std::distance(p.begin(), p.end()) != count);
...@@ -166,10 +165,10 @@ TEST_CASE(duplicate_args3) ...@@ -166,10 +165,10 @@ TEST_CASE(duplicate_args3)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_literal(0); auto l0 = mm->add_literal(0);
auto l3 = mm->add_literal(3); auto l3 = mm->add_literal(3);
auto sum1 = mm->add_instruction(migraphx::op::add{}, l0, l3); auto sum1 = mm->add_instruction(migraphx::make_op("add"), l0, l3);
auto sum2 = mm->add_instruction(migraphx::op::add{}, l0, sum1); auto sum2 = mm->add_instruction(migraphx::make_op("add"), l0, sum1);
mm->add_instruction(migraphx::op::add{}, sum2, l3); mm->add_instruction(migraphx::make_op("add"), sum2, l3);
mm->add_instruction(migraphx::op::identity{}, l0); mm->add_instruction(migraphx::make_op("identity"), l0);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) != count); EXPECT(std::distance(p.begin(), p.end()) != count);
......
#include <migraphx/decompose.hpp> #include <migraphx/decompose.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <migraphx/op/add.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p) void run_pass(migraphx::program& p)
...@@ -21,8 +18,8 @@ TEST_CASE(dot_add) ...@@ -21,8 +18,8 @@ TEST_CASE(dot_add)
auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = mm1->add_instruction(migraphx::op::dot{}, x, y, z); auto dot = mm1->add_instruction(migraphx::make_op("dot"), x, y, z);
mm1->add_instruction(migraphx::op::identity{}, dot); mm1->add_instruction(migraphx::make_op("identity"), dot);
} }
run_pass(p1); run_pass(p1);
migraphx::program p2; migraphx::program p2;
...@@ -31,9 +28,10 @@ TEST_CASE(dot_add) ...@@ -31,9 +28,10 @@ TEST_CASE(dot_add)
auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = mm2->add_instruction(migraphx::op::dot{1, 0}, x, y); auto dot =
auto add = mm2->add_instruction(migraphx::op::add{}, dot, z); mm2->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
mm2->add_instruction(migraphx::op::identity{}, add); auto add = mm2->add_instruction(migraphx::make_op("add"), dot, z);
mm2->add_instruction(migraphx::make_op("identity"), add);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -46,8 +44,9 @@ TEST_CASE(dot_add_beta_float) ...@@ -46,8 +44,9 @@ TEST_CASE(dot_add_beta_float)
auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = mm1->add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z); auto dot = mm1->add_instruction(
mm1->add_instruction(migraphx::op::identity{}, dot); migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
mm1->add_instruction(migraphx::make_op("identity"), dot);
} }
run_pass(p1); run_pass(p1);
migraphx::program p2; migraphx::program p2;
...@@ -56,13 +55,15 @@ TEST_CASE(dot_add_beta_float) ...@@ -56,13 +55,15 @@ TEST_CASE(dot_add_beta_float)
auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = mm2->add_instruction(migraphx::op::dot{1, 0}, x, y); auto dot =
mm2->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta = mm2->add_literal( auto beta = mm2->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}});
auto beta_broadcast = mm2->add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta); auto beta_broadcast = mm2->add_instruction(
auto mul = mm2->add_instruction(migraphx::op::mul{}, z, beta_broadcast); migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta);
auto add = mm2->add_instruction(migraphx::op::add{}, dot, mul); auto mul = mm2->add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
mm2->add_instruction(migraphx::op::identity{}, add); auto add = mm2->add_instruction(migraphx::make_op("add"), dot, mul);
mm2->add_instruction(migraphx::make_op("identity"), add);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -75,8 +76,9 @@ TEST_CASE(dot_add_beta_half) ...@@ -75,8 +76,9 @@ TEST_CASE(dot_add_beta_half)
auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot = mm1->add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z); auto dot = mm1->add_instruction(
mm1->add_instruction(migraphx::op::identity{}, dot); migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
mm1->add_instruction(migraphx::make_op("identity"), dot);
} }
run_pass(p1); run_pass(p1);
migraphx::program p2; migraphx::program p2;
...@@ -85,13 +87,15 @@ TEST_CASE(dot_add_beta_half) ...@@ -85,13 +87,15 @@ TEST_CASE(dot_add_beta_half)
auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot = mm2->add_instruction(migraphx::op::dot{1, 0}, x, y); auto dot =
mm2->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta = auto beta =
mm2->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}}); mm2->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}});
auto beta_broadcast = mm2->add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta); auto beta_broadcast = mm2->add_instruction(
auto mul = mm2->add_instruction(migraphx::op::mul{}, z, beta_broadcast); migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta);
auto add = mm2->add_instruction(migraphx::op::add{}, dot, mul); auto mul = mm2->add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
mm2->add_instruction(migraphx::op::identity{}, add); auto add = mm2->add_instruction(migraphx::make_op("add"), dot, mul);
mm2->add_instruction(migraphx::make_op("identity"), add);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -104,8 +108,9 @@ TEST_CASE(dot_add_beta_double) ...@@ -104,8 +108,9 @@ TEST_CASE(dot_add_beta_double)
auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto dot = mm1->add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z); auto dot = mm1->add_instruction(
mm1->add_instruction(migraphx::op::identity{}, dot); migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
mm1->add_instruction(migraphx::make_op("identity"), dot);
} }
run_pass(p1); run_pass(p1);
migraphx::program p2; migraphx::program p2;
...@@ -114,13 +119,15 @@ TEST_CASE(dot_add_beta_double) ...@@ -114,13 +119,15 @@ TEST_CASE(dot_add_beta_double)
auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto dot = mm2->add_instruction(migraphx::op::dot{1, 0}, x, y); auto dot =
mm2->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta = mm2->add_literal( auto beta = mm2->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}}); migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}});
auto beta_broadcast = mm2->add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta); auto beta_broadcast = mm2->add_instruction(
auto mul = mm2->add_instruction(migraphx::op::mul{}, z, beta_broadcast); migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta);
auto add = mm2->add_instruction(migraphx::op::add{}, dot, mul); auto mul = mm2->add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
mm2->add_instruction(migraphx::op::identity{}, add); auto add = mm2->add_instruction(migraphx::make_op("add"), dot, mul);
mm2->add_instruction(migraphx::make_op("identity"), add);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -133,8 +140,9 @@ TEST_CASE(dot_add_beta_int) ...@@ -133,8 +140,9 @@ TEST_CASE(dot_add_beta_int)
auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto dot = mm1->add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z); auto dot = mm1->add_instruction(
mm1->add_instruction(migraphx::op::identity{}, dot); migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
mm1->add_instruction(migraphx::make_op("identity"), dot);
} }
migraphx::program p2 = p1; migraphx::program p2 = p1;
run_pass(p1); run_pass(p1);
......
#include <migraphx/eliminate_common_subexpression.hpp> #include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/op/add.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p) void run_pass(migraphx::program& p)
...@@ -19,9 +20,9 @@ TEST_CASE(cse_test1) ...@@ -19,9 +20,9 @@ TEST_CASE(cse_test1)
auto* mm1 = p1.get_main_module(); auto* mm1 = p1.get_main_module();
auto one = mm1->add_literal(1); auto one = mm1->add_literal(1);
auto two = mm1->add_literal(2); auto two = mm1->add_literal(2);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, two); auto sum1 = mm1->add_instruction(migraphx::make_op("add"), one, two);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, one, two); auto sum2 = mm1->add_instruction(migraphx::make_op("add"), one, two);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2); auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
mm1->add_instruction(pass_op{}, sum3); mm1->add_instruction(pass_op{}, sum3);
} }
run_pass(p1); run_pass(p1);
...@@ -31,8 +32,8 @@ TEST_CASE(cse_test1) ...@@ -31,8 +32,8 @@ TEST_CASE(cse_test1)
auto* mm2 = p2.get_main_module(); auto* mm2 = p2.get_main_module();
auto one = mm2->add_literal(1); auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2); auto two = mm2->add_literal(2);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, two); auto sum1 = mm2->add_instruction(migraphx::make_op("add"), one, two);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum1, sum1); auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum1, sum1);
mm2->add_instruction(pass_op{}, sum3); mm2->add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
...@@ -45,9 +46,9 @@ TEST_CASE(cse_test2) ...@@ -45,9 +46,9 @@ TEST_CASE(cse_test2)
auto* mm1 = p1.get_main_module(); auto* mm1 = p1.get_main_module();
auto one = mm1->add_literal(1); auto one = mm1->add_literal(1);
auto two = mm1->add_literal(2); auto two = mm1->add_literal(2);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, two); auto sum1 = mm1->add_instruction(migraphx::make_op("add"), one, two);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, two, one); auto sum2 = mm1->add_instruction(migraphx::make_op("add"), two, one);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2); auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
mm1->add_instruction(pass_op{}, sum3); mm1->add_instruction(pass_op{}, sum3);
} }
run_pass(p1); run_pass(p1);
...@@ -57,9 +58,9 @@ TEST_CASE(cse_test2) ...@@ -57,9 +58,9 @@ TEST_CASE(cse_test2)
auto* mm2 = p2.get_main_module(); auto* mm2 = p2.get_main_module();
auto one = mm2->add_literal(1); auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2); auto two = mm2->add_literal(2);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, two); auto sum1 = mm2->add_instruction(migraphx::make_op("add"), one, two);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, two, one); auto sum2 = mm2->add_instruction(migraphx::make_op("add"), two, one);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum1, sum2); auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum1, sum2);
mm2->add_instruction(pass_op{}, sum3); mm2->add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
...@@ -72,9 +73,9 @@ TEST_CASE(cse_test3) ...@@ -72,9 +73,9 @@ TEST_CASE(cse_test3)
auto* mm1 = p1.get_main_module(); auto* mm1 = p1.get_main_module();
auto one = mm1->add_literal(1); auto one = mm1->add_literal(1);
auto two = mm1->add_literal(1); auto two = mm1->add_literal(1);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, two); auto sum1 = mm1->add_instruction(migraphx::make_op("add"), one, two);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, two, one); auto sum2 = mm1->add_instruction(migraphx::make_op("add"), two, one);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2); auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
mm1->add_instruction(pass_op{}, sum3); mm1->add_instruction(pass_op{}, sum3);
} }
run_pass(p1); run_pass(p1);
...@@ -83,8 +84,8 @@ TEST_CASE(cse_test3) ...@@ -83,8 +84,8 @@ TEST_CASE(cse_test3)
{ {
auto* mm2 = p2.get_main_module(); auto* mm2 = p2.get_main_module();
auto one = mm2->add_literal(1); auto one = mm2->add_literal(1);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, one); auto sum1 = mm2->add_instruction(migraphx::make_op("add"), one, one);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum1, sum1); auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum1, sum1);
mm2->add_instruction(pass_op{}, sum3); mm2->add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
...@@ -97,11 +98,11 @@ TEST_CASE(cse_test4) ...@@ -97,11 +98,11 @@ TEST_CASE(cse_test4)
auto* mm1 = p1.get_main_module(); auto* mm1 = p1.get_main_module();
auto one = mm1->add_literal(1); auto one = mm1->add_literal(1);
auto two = mm1->add_literal(1); auto two = mm1->add_literal(1);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, two); auto sum1 = mm1->add_instruction(migraphx::make_op("add"), one, two);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, two, one); auto sum2 = mm1->add_instruction(migraphx::make_op("add"), two, one);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, one); auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, one);
auto sum4 = mm1->add_instruction(migraphx::op::add{}, sum2, two); auto sum4 = mm1->add_instruction(migraphx::make_op("add"), sum2, two);
auto sum5 = mm1->add_instruction(migraphx::op::add{}, sum4, sum3); auto sum5 = mm1->add_instruction(migraphx::make_op("add"), sum4, sum3);
mm1->add_instruction(pass_op{}, sum5); mm1->add_instruction(pass_op{}, sum5);
} }
run_pass(p1); run_pass(p1);
...@@ -110,9 +111,9 @@ TEST_CASE(cse_test4) ...@@ -110,9 +111,9 @@ TEST_CASE(cse_test4)
{ {
auto* mm2 = p2.get_main_module(); auto* mm2 = p2.get_main_module();
auto one = mm2->add_literal(1); auto one = mm2->add_literal(1);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, one); auto sum1 = mm2->add_instruction(migraphx::make_op("add"), one, one);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum1, one); auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum1, one);
auto sum5 = mm2->add_instruction(migraphx::op::add{}, sum3, sum3); auto sum5 = mm2->add_instruction(migraphx::make_op("add"), sum3, sum3);
mm2->add_instruction(pass_op{}, sum5); mm2->add_instruction(pass_op{}, sum5);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
...@@ -130,11 +131,11 @@ TEST_CASE(cse_test_literal) ...@@ -130,11 +131,11 @@ TEST_CASE(cse_test_literal)
auto six3 = mm1->add_literal(6); auto six3 = mm1->add_literal(6);
auto zero3 = mm1->add_literal(0); auto zero3 = mm1->add_literal(0);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, six1, zero1); auto sum1 = mm1->add_instruction(migraphx::make_op("add"), six1, zero1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, six2, zero2); auto sum2 = mm1->add_instruction(migraphx::make_op("add"), six2, zero2);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, six3, zero3); auto sum3 = mm1->add_instruction(migraphx::make_op("add"), six3, zero3);
auto sum4 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2); auto sum4 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
auto sum5 = mm1->add_instruction(migraphx::op::add{}, sum3, sum4); auto sum5 = mm1->add_instruction(migraphx::make_op("add"), sum3, sum4);
mm1->add_instruction(pass_op{}, sum5); mm1->add_instruction(pass_op{}, sum5);
} }
run_pass(p1); run_pass(p1);
...@@ -144,9 +145,9 @@ TEST_CASE(cse_test_literal) ...@@ -144,9 +145,9 @@ TEST_CASE(cse_test_literal)
auto* mm2 = p2.get_main_module(); auto* mm2 = p2.get_main_module();
auto six = mm2->add_literal(6); auto six = mm2->add_literal(6);
auto zero = mm2->add_literal(0); auto zero = mm2->add_literal(0);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, six, zero); auto sum1 = mm2->add_instruction(migraphx::make_op("add"), six, zero);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, sum1, sum1); auto sum2 = mm2->add_instruction(migraphx::make_op("add"), sum1, sum1);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum1, sum2); auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum1, sum2);
mm2->add_instruction(pass_op{}, sum3); mm2->add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
......
#include <migraphx/eliminate_contiguous.hpp> #include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/sin.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/contiguous.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p) void run_pass(migraphx::program& p)
...@@ -22,8 +18,8 @@ TEST_CASE(standard_op) ...@@ -22,8 +18,8 @@ TEST_CASE(standard_op)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2}}); auto l = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t); auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
mm->add_instruction(pass_standard_op{}, c); mm->add_instruction(pass_standard_op{}, c);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
...@@ -36,8 +32,8 @@ TEST_CASE(standard_op_const) ...@@ -36,8 +32,8 @@ TEST_CASE(standard_op_const)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t); auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
mm->add_instruction(pass_standard_op{}, c); mm->add_instruction(pass_standard_op{}, c);
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(p.begin(), p.end()) == 2);
...@@ -49,8 +45,8 @@ TEST_CASE(non_standard_op) ...@@ -49,8 +45,8 @@ TEST_CASE(non_standard_op)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2}}); auto l = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t); auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
mm->add_instruction(pass_op{}, c); mm->add_instruction(pass_op{}, c);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
...@@ -63,8 +59,8 @@ TEST_CASE(non_standard_op_const) ...@@ -63,8 +59,8 @@ TEST_CASE(non_standard_op_const)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t); auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
mm->add_instruction(pass_op{}, c); mm->add_instruction(pass_op{}, c);
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(p.begin(), p.end()) == 2);
...@@ -76,10 +72,10 @@ TEST_CASE(transpose_gemm) ...@@ -76,10 +72,10 @@ TEST_CASE(transpose_gemm)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t); auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
auto ic = mm->add_instruction(migraphx::op::identity{}, c); auto ic = mm->add_instruction(migraphx::make_op("identity"), c);
mm->add_instruction(migraphx::op::dot{}, ic, l); mm->add_instruction(migraphx::make_op("dot"), ic, l);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == (count - 1)); EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
...@@ -91,9 +87,9 @@ TEST_CASE(transpose_standard_op) ...@@ -91,9 +87,9 @@ TEST_CASE(transpose_standard_op)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2}}); auto l = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t); auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
auto sn = mm->add_instruction(migraphx::op::sin{}, c); auto sn = mm->add_instruction(migraphx::make_op("sin"), c);
mm->add_instruction(pass_standard_op{}, sn); mm->add_instruction(pass_standard_op{}, sn);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
...@@ -106,9 +102,9 @@ TEST_CASE(transpose_standard_op_const) ...@@ -106,9 +102,9 @@ TEST_CASE(transpose_standard_op_const)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, t); auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
auto sn = mm->add_instruction(migraphx::op::sin{}, c); auto sn = mm->add_instruction(migraphx::make_op("sin"), c);
mm->add_instruction(pass_standard_op{}, sn); mm->add_instruction(pass_standard_op{}, sn);
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 3); EXPECT(std::distance(p.begin(), p.end()) == 3);
...@@ -120,9 +116,10 @@ TEST_CASE(no_packed_unary_op) ...@@ -120,9 +116,10 @@ TEST_CASE(no_packed_unary_op)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
auto t = mm->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, l); auto t = mm->add_instruction(
auto c = mm->add_instruction(migraphx::op::contiguous{}, t); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), l);
auto sn = mm->add_instruction(migraphx::op::sin{}, c); auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
auto sn = mm->add_instruction(migraphx::make_op("sin"), c);
mm->add_instruction(pass_standard_op{}, sn); mm->add_instruction(pass_standard_op{}, sn);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
...@@ -135,8 +132,8 @@ TEST_CASE(non_standard_return_input) ...@@ -135,8 +132,8 @@ TEST_CASE(non_standard_return_input)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
auto tl = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l); auto tl = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto c = mm->add_instruction(migraphx::op::contiguous{}, tl); auto c = mm->add_instruction(migraphx::make_op("contiguous"), tl);
mm->add_return({c}); mm->add_return({c});
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
......
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/make_op.hpp>
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p) void run_pass(migraphx::program& p)
...@@ -18,9 +19,9 @@ TEST_CASE(simple_test) ...@@ -18,9 +19,9 @@ TEST_CASE(simple_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto one_identity = mm->add_instruction(migraphx::op::identity{}, one); auto one_identity = mm->add_instruction(migraphx::make_op("identity"), one);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto two_identity = mm->add_instruction(migraphx::op::identity{}, two); auto two_identity = mm->add_instruction(migraphx::make_op("identity"), two);
mm->add_instruction(sum_op{}, one_identity, two_identity); mm->add_instruction(sum_op{}, one_identity, two_identity);
run_pass(p); run_pass(p);
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) { EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
...@@ -39,7 +40,7 @@ TEST_CASE(simple_test_end) ...@@ -39,7 +40,7 @@ TEST_CASE(simple_test_end)
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto ans = mm->add_instruction(sum_op{}, one, two); auto ans = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::op::identity{}, ans); mm->add_instruction(migraphx::make_op("identity"), ans);
run_pass(p); run_pass(p);
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) { EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity"; return ins.name() == "identity";
...@@ -59,7 +60,7 @@ TEST_CASE(simple_test_end_dependency) ...@@ -59,7 +60,7 @@ TEST_CASE(simple_test_end_dependency)
auto three = mm->add_literal(3.0); auto three = mm->add_literal(3.0);
auto ans = mm->add_instruction(sum_op{}, one, two); auto ans = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(sum_op{}, ans, three); mm->add_instruction(sum_op{}, ans, three);
mm->add_instruction(migraphx::op::identity{}, ans); mm->add_instruction(migraphx::make_op("identity"), ans);
run_pass(p); run_pass(p);
EXPECT(std::any_of(p.begin(), p.end(), [](const migraphx::instruction& ins) { EXPECT(std::any_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity"; return ins.name() == "identity";
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p) void run_pass(migraphx::program& p)
...@@ -20,7 +22,7 @@ create_im2col(migraphx::instruction_ref& l_img, size_t channels, migraphx::progr ...@@ -20,7 +22,7 @@ create_im2col(migraphx::instruction_ref& l_img, size_t channels, migraphx::progr
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}}; migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_weights = mm->add_literal(migraphx::literal{s_weights, weights}); auto l_weights = mm->add_literal(migraphx::literal{s_weights, weights});
return mm->add_instruction(migraphx::op::im2col{}, l_img, l_weights); return mm->add_instruction(migraphx::make_op("im2col"), l_img, l_weights);
} }
migraphx::instruction_ref migraphx::instruction_ref
...@@ -48,13 +50,14 @@ TEST_CASE(rewrite_pad) ...@@ -48,13 +50,14 @@ TEST_CASE(rewrite_pad)
std::iota(input.begin(), input.end(), 0); std::iota(input.begin(), input.end(), 0);
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}}; migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = mm->add_literal(migraphx::literal{s_img, input}); auto l_img = mm->add_literal(migraphx::literal{s_img, input});
auto padded_img = mm->add_instruction(migraphx::op::pad{{0, 0, 1, 1, 0, 0, 1, 1}}, l_img); auto padded_img =
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 1, 1, 0, 0, 1, 1}}}), l_img);
auto l0 = create_im2col(padded_img, channels, p); auto l0 = create_im2col(padded_img, channels, p);
auto l1 = create_conv(padded_img, channels, p); auto l1 = create_conv(padded_img, channels, p);
auto l2 = mm->add_instruction(migraphx::op::pooling{"max"}, padded_img); auto l2 = mm->add_instruction(migraphx::make_op("pooling", {{"mode", "max"}}), padded_img);
mm->add_instruction(migraphx::op::identity{}, l0, l1, l2); mm->add_instruction(migraphx::make_op("identity"), l0, l1, l2);
auto s0 = l0->get_shape(); auto s0 = l0->get_shape();
auto s1 = l1->get_shape(); auto s1 = l1->get_shape();
...@@ -86,8 +89,9 @@ TEST_CASE(rewrite_pad_im2col_asymmetric) ...@@ -86,8 +89,9 @@ TEST_CASE(rewrite_pad_im2col_asymmetric)
std::iota(input.begin(), input.end(), 0); std::iota(input.begin(), input.end(), 0);
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}}; migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = mm->add_literal(migraphx::literal{s_img, input}); auto l_img = mm->add_literal(migraphx::literal{s_img, input});
auto padded_img = mm->add_instruction(migraphx::op::pad{{0, 0, 0, 0, 0, 0, 2, 2}}, l_img); auto padded_img =
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 2, 2}}}), l_img);
auto l0 = create_im2col(padded_img, channels, p); auto l0 = create_im2col(padded_img, channels, p);
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -184,7 +184,7 @@ TEST_CASE(check_to_value2) ...@@ -184,7 +184,7 @@ TEST_CASE(check_to_value2)
{ {
migraphx::operation op = simple_operation{}; migraphx::operation op = simple_operation{};
auto v = migraphx::to_value(op); auto v = migraphx::to_value(op);
EXPECT(v == migraphx::value{{"data", 1}}); EXPECT(v == migraphx::value{{"name", "simple"}, {"operator", {{"data", 1}}}});
} }
TEST_CASE(check_from_value1) TEST_CASE(check_from_value1)
......
...@@ -5,6 +5,10 @@ ...@@ -5,6 +5,10 @@
#include <migraphx/op/rnn_variable_seq_lens.hpp> #include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include "test.hpp" #include "test.hpp"
TEST_CASE(load_op) TEST_CASE(load_op)
...@@ -25,18 +29,30 @@ TEST_CASE(make_op) ...@@ -25,18 +29,30 @@ TEST_CASE(make_op)
} }
} }
TEST_CASE(save_op)
{
for(const auto& name : migraphx::get_operators())
{
auto op1 = migraphx::load_op(name);
auto v = migraphx::to_value(op1);
auto op2 = migraphx::from_value<migraphx::operation>(v);
CHECK(op1 == op2);
}
}
TEST_CASE(make_op_from_value1) TEST_CASE(make_op_from_value1)
{ {
migraphx::operation x = migraphx::make_op( migraphx::operation x = migraphx::make_op(
"convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {2, 2}}}); "convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {2, 2}}});
migraphx::operation y = migraphx::op::convolution{{1, 1}, {2, 2}, {2, 2}}; migraphx::operation y = migraphx::make_op(
"convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {2, 2}}});
EXPECT(x == y); EXPECT(x == y);
} }
TEST_CASE(make_op_from_value2) TEST_CASE(make_op_from_value2)
{ {
migraphx::operation x = migraphx::make_op("convolution", {{"padding", {1, 1}}}); migraphx::operation x = migraphx::make_op("convolution", {{"padding", {1, 1}}});
migraphx::operation y = migraphx::op::convolution{{1, 1}}; migraphx::operation y = migraphx::make_op("convolution", {{"padding", {1, 1}}});
EXPECT(x == y); EXPECT(x == y);
} }
...@@ -45,7 +61,9 @@ TEST_CASE(make_rnn_op_from_value) ...@@ -45,7 +61,9 @@ TEST_CASE(make_rnn_op_from_value)
migraphx::op::rnn_direction dirct = migraphx::op::rnn_direction::reverse; migraphx::op::rnn_direction dirct = migraphx::op::rnn_direction::reverse;
migraphx::operation x = migraphx::make_op( migraphx::operation x = migraphx::make_op(
"rnn_var_sl_shift_output", {{"output_name", "hidden_states"}, {"direction", dirct}}); "rnn_var_sl_shift_output", {{"output_name", "hidden_states"}, {"direction", dirct}});
migraphx::operation y = migraphx::op::rnn_var_sl_shift_output{"hidden_states", dirct}; migraphx::operation y = migraphx::make_op(
"rnn_var_sl_shift_output",
{{"output_name", "hidden_states"}, {"direction", migraphx::to_value(dirct)}});
EXPECT(x == y); EXPECT(x == y);
} }
......
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include "test.hpp" #include "test.hpp"
TEST_CASE(perf_report) TEST_CASE(perf_report)
...@@ -12,7 +13,7 @@ TEST_CASE(perf_report) ...@@ -12,7 +13,7 @@ TEST_CASE(perf_report)
std::stringstream ss; std::stringstream ss;
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(migraphx::op::add{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
p.perf_report(ss, 2, {}); p.perf_report(ss, 2, {});
......
...@@ -2,12 +2,11 @@ ...@@ -2,12 +2,11 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <sstream> #include <sstream>
#include "test.hpp" #include "test.hpp"
#include <migraphx/make_op.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
migraphx::program create_program() migraphx::program create_program()
...@@ -69,8 +68,8 @@ TEST_CASE(program_copy) ...@@ -69,8 +68,8 @@ TEST_CASE(program_copy)
auto l2 = mm->add_literal(migraphx::literal(s, data)); auto l2 = mm->add_literal(migraphx::literal(s, data));
auto p1 = mm->add_parameter("x", s); auto p1 = mm->add_parameter("x", s);
auto po = mm->add_outline(s); auto po = mm->add_outline(s);
auto sum = mm->add_instruction(migraphx::op::add{}, l2, po); auto sum = mm->add_instruction(migraphx::make_op("add"), l2, po);
mm->add_instruction(migraphx::op::mul{}, sum, p1); mm->add_instruction(migraphx::make_op("mul"), sum, p1);
return p; return p;
}; };
...@@ -125,7 +124,8 @@ TEST_CASE(program_copy) ...@@ -125,7 +124,8 @@ TEST_CASE(program_copy)
auto para1 = mm1->add_parameter("m1", s1); auto para1 = mm1->add_parameter("m1", s1);
auto para2 = mm1->add_parameter("m2", s2); auto para2 = mm1->add_parameter("m2", s2);
auto para3 = mm1->add_parameter("m3", s3); auto para3 = mm1->add_parameter("m3", s3);
mm1->add_instruction(migraphx::op::dot{0.31f, 0.28f}, para1, para2, para3); mm1->add_instruction(
migraphx::make_op("dot", {{"alpha", 0.31f}, {"beta", 0.28f}}), para1, para2, para3);
migraphx::program p2{}; migraphx::program p2{};
p2 = p1; p2 = p1;
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment