Unverified Commit 0279acec authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Refactor pass tests (#736)

* Update test for passes

* Formatting

* Rewrite simplify_reshapes

* Formatting

* Rewrite normalize pass

* Formatting

* Rewrite pooling

* Formatting

* Rewrite schedule tests

* Formatting
parent b889d472
...@@ -6,108 +6,99 @@ ...@@ -6,108 +6,99 @@
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p) void run_pass(migraphx::module& m) { migraphx::run_passes(m, {migraphx::auto_contiguous{}}); }
{
migraphx::run_passes(*p.get_main_module(), {migraphx::auto_contiguous{}});
}
// TODO: Add this test case // TODO: Add this test case
void literal_broadcast() void literal_broadcast()
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); m.add_literal(get_2_broadcasted());
mm->add_literal(get_2_broadcasted()); EXPECT(not m.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().standard()); EXPECT(m.get_output_shapes().back().broadcasted());
EXPECT(p.get_output_shapes().back().broadcasted()); run_pass(m);
run_pass(p); EXPECT(m.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().standard()); EXPECT(not m.get_output_shapes().back().broadcasted());
EXPECT(not p.get_output_shapes().back().broadcasted());
} }
TEST_CASE(literal_transpose) TEST_CASE(literal_transpose)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); m.add_literal(get_2x2_transposed());
mm->add_literal(get_2x2_transposed()); EXPECT(not m.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().standard()); EXPECT(m.get_output_shapes().back().transposed());
EXPECT(p.get_output_shapes().back().transposed()); run_pass(m);
run_pass(p); EXPECT(m.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().standard()); EXPECT(not m.get_output_shapes().back().transposed());
EXPECT(not p.get_output_shapes().back().transposed());
} }
TEST_CASE(after_literal_transpose) TEST_CASE(after_literal_transpose)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto l = m.add_literal(get_2x2());
auto l = mm->add_literal(get_2x2()); EXPECT(m.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().standard()); EXPECT(not m.get_output_shapes().back().transposed());
EXPECT(not p.get_output_shapes().back().transposed()); auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); m.add_instruction(pass_op{}, t);
mm->add_instruction(pass_op{}, t); EXPECT(not m.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().standard()); EXPECT(m.get_output_shapes().back().transposed());
EXPECT(p.get_output_shapes().back().transposed()); run_pass(m);
run_pass(p); EXPECT(m.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().standard()); EXPECT(not m.get_output_shapes().back().transposed());
EXPECT(not p.get_output_shapes().back().transposed());
} }
TEST_CASE(after_literal_broadcast) TEST_CASE(after_literal_broadcast)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto l1 = m.add_literal(get_2x2());
auto l1 = mm->add_literal(get_2x2()); auto l2 = m.add_literal(get_2());
auto l2 = mm->add_literal(get_2()); EXPECT(m.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().standard()); EXPECT(not m.get_output_shapes().back().broadcasted());
EXPECT(not p.get_output_shapes().back().broadcasted()); auto b = m.add_instruction(
auto b = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", l1->get_shape().lens()}}), l2); migraphx::make_op("broadcast", {{"axis", 0}, {"dims", l1->get_shape().lens()}}), l2);
mm->add_instruction(pass_op{}, b); m.add_instruction(pass_op{}, b);
EXPECT(not p.get_output_shapes().back().standard()); EXPECT(not m.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().broadcasted()); EXPECT(m.get_output_shapes().back().broadcasted());
run_pass(p); run_pass(m);
EXPECT(p.get_output_shapes().back().standard()); EXPECT(m.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().broadcasted()); EXPECT(not m.get_output_shapes().back().broadcasted());
} }
TEST_CASE(after_param_transpose) TEST_CASE(after_param_transpose)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto l = m.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
auto l = mm->add_parameter("2x2", {migraphx::shape::float_type, {2, 2}}); EXPECT(m.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().standard()); EXPECT(not m.get_output_shapes().back().transposed());
EXPECT(not p.get_output_shapes().back().transposed()); auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); m.add_instruction(pass_op{}, t);
mm->add_instruction(pass_op{}, t); EXPECT(not m.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().standard()); EXPECT(m.get_output_shapes().back().transposed());
EXPECT(p.get_output_shapes().back().transposed()); run_pass(m);
run_pass(p); EXPECT(m.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().standard()); EXPECT(not m.get_output_shapes().back().transposed());
EXPECT(not p.get_output_shapes().back().transposed());
} }
TEST_CASE(after_param_broadcast) TEST_CASE(after_param_broadcast)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto l1 = m.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
auto l1 = mm->add_parameter("2x2", {migraphx::shape::float_type, {2, 2}}); auto l2 = m.add_parameter("2", {migraphx::shape::float_type, {2}});
auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {2}}); EXPECT(m.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().standard()); EXPECT(not m.get_output_shapes().back().broadcasted());
EXPECT(not p.get_output_shapes().back().broadcasted()); auto b = m.add_instruction(
auto b = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", l1->get_shape().lens()}}), l2); migraphx::make_op("broadcast", {{"axis", 0}, {"dims", l1->get_shape().lens()}}), l2);
mm->add_instruction(pass_op{}, b); m.add_instruction(pass_op{}, b);
EXPECT(not p.get_output_shapes().back().standard()); EXPECT(not m.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().broadcasted()); EXPECT(m.get_output_shapes().back().broadcasted());
run_pass(p); run_pass(m);
EXPECT(p.get_output_shapes().back().standard()); EXPECT(m.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().broadcasted()); EXPECT(not m.get_output_shapes().back().broadcasted());
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -5,148 +5,132 @@ ...@@ -5,148 +5,132 @@
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p) void run_pass(migraphx::module& m) { migraphx::run_passes(m, {migraphx::decompose{}}); }
{
migraphx::run_passes(*p.get_main_module(), {migraphx::decompose{}});
}
TEST_CASE(dot_add) TEST_CASE(dot_add)
{ {
migraphx::program p1; migraphx::module m1;
{ {
auto* mm1 = p1.get_main_module(); auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto dot = m1.add_instruction(migraphx::make_op("dot"), x, y, z);
auto dot = mm1->add_instruction(migraphx::make_op("dot"), x, y, z); m1.add_instruction(migraphx::make_op("identity"), dot);
mm1->add_instruction(migraphx::make_op("identity"), dot);
} }
run_pass(p1); run_pass(m1);
migraphx::program p2; migraphx::module m2;
{ {
auto* mm2 = p2.get_main_module(); auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto dot = auto add = m2.add_instruction(migraphx::make_op("add"), dot, z);
mm2->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y); m2.add_instruction(migraphx::make_op("identity"), add);
auto add = mm2->add_instruction(migraphx::make_op("add"), dot, z);
mm2->add_instruction(migraphx::make_op("identity"), add);
} }
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
TEST_CASE(dot_add_beta_float) TEST_CASE(dot_add_beta_float)
{ {
migraphx::program p1; migraphx::module m1;
{ {
auto* mm1 = p1.get_main_module(); auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto dot =
auto dot = mm1->add_instruction( m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z); m1.add_instruction(migraphx::make_op("identity"), dot);
mm1->add_instruction(migraphx::make_op("identity"), dot);
} }
run_pass(p1); run_pass(m1);
migraphx::program p2; migraphx::module m2;
{ {
auto* mm2 = p2.get_main_module(); auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto dot = auto beta =
mm2->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y); m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}});
auto beta = mm2->add_literal( auto beta_broadcast = m2.add_instruction(
migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}});
auto beta_broadcast = mm2->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta); migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta);
auto mul = mm2->add_instruction(migraphx::make_op("mul"), z, beta_broadcast); auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = mm2->add_instruction(migraphx::make_op("add"), dot, mul); auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
mm2->add_instruction(migraphx::make_op("identity"), add); m2.add_instruction(migraphx::make_op("identity"), add);
} }
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
TEST_CASE(dot_add_beta_half) TEST_CASE(dot_add_beta_half)
{ {
migraphx::program p1; migraphx::module m1;
{ {
auto* mm1 = p1.get_main_module(); auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto dot =
auto dot = mm1->add_instruction( m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z); m1.add_instruction(migraphx::make_op("identity"), dot);
mm1->add_instruction(migraphx::make_op("identity"), dot);
} }
run_pass(p1); run_pass(m1);
migraphx::program p2; migraphx::module m2;
{ {
auto* mm2 = p2.get_main_module(); auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}}); auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto dot =
mm2->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta = auto beta =
mm2->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}}); m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}});
auto beta_broadcast = mm2->add_instruction( auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta); migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta);
auto mul = mm2->add_instruction(migraphx::make_op("mul"), z, beta_broadcast); auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = mm2->add_instruction(migraphx::make_op("add"), dot, mul); auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
mm2->add_instruction(migraphx::make_op("identity"), add); m2.add_instruction(migraphx::make_op("identity"), add);
} }
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
TEST_CASE(dot_add_beta_double) TEST_CASE(dot_add_beta_double)
{ {
migraphx::program p1; migraphx::module m1;
{ {
auto* mm1 = p1.get_main_module(); auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto dot =
auto dot = mm1->add_instruction( m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z); m1.add_instruction(migraphx::make_op("identity"), dot);
mm1->add_instruction(migraphx::make_op("identity"), dot);
} }
run_pass(p1); run_pass(m1);
migraphx::program p2; migraphx::module m2;
{ {
auto* mm2 = p2.get_main_module(); auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto x = mm2->add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = mm2->add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = mm2->add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto dot = auto beta =
mm2->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y); m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}});
auto beta = mm2->add_literal( auto beta_broadcast = m2.add_instruction(
migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}});
auto beta_broadcast = mm2->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta); migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta);
auto mul = mm2->add_instruction(migraphx::make_op("mul"), z, beta_broadcast); auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = mm2->add_instruction(migraphx::make_op("add"), dot, mul); auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
mm2->add_instruction(migraphx::make_op("identity"), add); m2.add_instruction(migraphx::make_op("identity"), add);
} }
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
TEST_CASE(dot_add_beta_int) TEST_CASE(dot_add_beta_int)
{ {
migraphx::program p1; migraphx::module m1;
{ {
auto* mm1 = p1.get_main_module(); auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto x = mm1->add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto y = mm1->add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto z = mm1->add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); auto dot =
auto dot = mm1->add_instruction( m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z); m1.add_instruction(migraphx::make_op("identity"), dot);
mm1->add_instruction(migraphx::make_op("identity"), dot);
} }
migraphx::program p2 = p1; migraphx::module m2 = m1;
run_pass(p1); run_pass(m1);
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -6,11 +6,10 @@ ...@@ -6,11 +6,10 @@
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p, std::size_t align = 32) void run_pass(migraphx::module& m, std::size_t align = 32)
{ {
migraphx::run_passes( migraphx::run_passes(
*p.get_main_module(), m, {migraphx::eliminate_allocation{"allocate", align}, migraphx::dead_code_elimination{}});
{migraphx::eliminate_allocation{"allocate", align}, migraphx::dead_code_elimination{}});
} }
struct allocate struct allocate
...@@ -39,78 +38,74 @@ struct allocate ...@@ -39,78 +38,74 @@ struct allocate
TEST_CASE(basic) TEST_CASE(basic)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto a1 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {8}}});
auto a1 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {8}}}); auto m1 = m.add_instruction(pass_op{}, a1);
auto p1 = mm->add_instruction(pass_op{}, a1);
auto a2 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {40}}}); auto a2 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {40}}});
auto p2 = mm->add_instruction(pass_op{}, a2, p1); auto m2 = m.add_instruction(pass_op{}, a2, m1);
auto a3 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}}); auto a3 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
mm->add_instruction(pass_op{}, a3, p2); m.add_instruction(pass_op{}, a3, m2);
run_pass(p); run_pass(m);
EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}}); EXPECT(m.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}});
EXPECT(p.get_parameter_shape("memory").bytes() == (8 * 4 + 40 * 4 + 200 * 4)); EXPECT(m.get_parameter_shape("memory").bytes() == (8 * 4 + 40 * 4 + 200 * 4));
} }
TEST_CASE(aligned) TEST_CASE(aligned)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto a1 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
auto a1 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}}); auto m1 = m.add_instruction(pass_op{}, a1);
auto p1 = mm->add_instruction(pass_op{}, a1);
auto a2 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}}); auto a2 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}});
auto p2 = mm->add_instruction(pass_op{}, a2, p1); auto m2 = m.add_instruction(pass_op{}, a2, m1);
auto a3 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}}); auto a3 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
mm->add_instruction(pass_op{}, a3, p2); m.add_instruction(pass_op{}, a3, m2);
run_pass(p); run_pass(m);
EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}}); EXPECT(m.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}});
EXPECT(p.get_parameter_shape("memory").bytes() == (32 + 32 + 200 * 4)); EXPECT(m.get_parameter_shape("memory").bytes() == (32 + 32 + 200 * 4));
} }
TEST_CASE(unaligned) TEST_CASE(unaligned)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto a1 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
auto a1 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}}); auto m1 = m.add_instruction(pass_op{}, a1);
auto p1 = mm->add_instruction(pass_op{}, a1);
auto a2 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}}); auto a2 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}});
auto p2 = mm->add_instruction(pass_op{}, a2, p1); auto m2 = m.add_instruction(pass_op{}, a2, m1);
auto a3 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}}); auto a3 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
mm->add_instruction(pass_op{}, a3, p2); m.add_instruction(pass_op{}, a3, m2);
run_pass(p, 1); run_pass(m, 1);
EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}}); EXPECT(m.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}});
EXPECT(p.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4)); EXPECT(m.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4));
} }
TEST_CASE(float_aligned) TEST_CASE(float_aligned)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto a1 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
auto a1 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}}); auto m1 = m.add_instruction(pass_op{}, a1);
auto p1 = mm->add_instruction(pass_op{}, a1);
auto a2 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}}); auto a2 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}});
auto p2 = mm->add_instruction(pass_op{}, a2, p1); auto m2 = m.add_instruction(pass_op{}, a2, m1);
auto a3 = mm->add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}}); auto a3 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
mm->add_instruction(pass_op{}, a3, p2); m.add_instruction(pass_op{}, a3, m2);
run_pass(p, 4); run_pass(m, 4);
EXPECT(p.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}}); EXPECT(m.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}});
EXPECT(p.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4)); EXPECT(m.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -6,151 +6,140 @@ ...@@ -6,151 +6,140 @@
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p) void run_pass(migraphx::module& m)
{ {
migraphx::run_passes( migraphx::run_passes(
*p.get_main_module(), m, {migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}});
{migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}});
} }
TEST_CASE(cse_test1) TEST_CASE(cse_test1)
{ {
migraphx::program p1; migraphx::module m1;
{ {
auto* mm1 = p1.get_main_module(); auto one = m1.add_literal(1);
auto one = mm1->add_literal(1); auto two = m1.add_literal(2);
auto two = mm1->add_literal(2); auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, two);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), one, two); auto sum2 = m1.add_instruction(migraphx::make_op("add"), one, two);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), one, two); auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2);
auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2); m1.add_instruction(pass_op{}, sum3);
mm1->add_instruction(pass_op{}, sum3);
} }
run_pass(p1); run_pass(m1);
migraphx::program p2; migraphx::module m2;
{ {
auto* mm2 = p2.get_main_module(); auto one = m2.add_literal(1);
auto one = mm2->add_literal(1); auto two = m2.add_literal(2);
auto two = mm2->add_literal(2); auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, two);
auto sum1 = mm2->add_instruction(migraphx::make_op("add"), one, two); auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum1, sum1);
auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum1, sum1); m2.add_instruction(pass_op{}, sum3);
mm2->add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
TEST_CASE(cse_test2) TEST_CASE(cse_test2)
{ {
migraphx::program p1; migraphx::module m1;
{ {
auto* mm1 = p1.get_main_module(); auto one = m1.add_literal(1);
auto one = mm1->add_literal(1); auto two = m1.add_literal(2);
auto two = mm1->add_literal(2); auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, two);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), one, two); auto sum2 = m1.add_instruction(migraphx::make_op("add"), two, one);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), two, one); auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2);
auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2); m1.add_instruction(pass_op{}, sum3);
mm1->add_instruction(pass_op{}, sum3);
} }
run_pass(p1); run_pass(m1);
migraphx::program p2; migraphx::module m2;
{ {
auto* mm2 = p2.get_main_module(); auto one = m2.add_literal(1);
auto one = mm2->add_literal(1); auto two = m2.add_literal(2);
auto two = mm2->add_literal(2); auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, two);
auto sum1 = mm2->add_instruction(migraphx::make_op("add"), one, two); auto sum2 = m2.add_instruction(migraphx::make_op("add"), two, one);
auto sum2 = mm2->add_instruction(migraphx::make_op("add"), two, one); auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum1, sum2);
auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum1, sum2); m2.add_instruction(pass_op{}, sum3);
mm2->add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
TEST_CASE(cse_test3) TEST_CASE(cse_test3)
{ {
migraphx::program p1; migraphx::module m1;
{ {
auto* mm1 = p1.get_main_module(); auto one = m1.add_literal(1);
auto one = mm1->add_literal(1); auto two = m1.add_literal(1);
auto two = mm1->add_literal(1); auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, two);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), one, two); auto sum2 = m1.add_instruction(migraphx::make_op("add"), two, one);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), two, one); auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2);
auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2); m1.add_instruction(pass_op{}, sum3);
mm1->add_instruction(pass_op{}, sum3);
} }
run_pass(p1); run_pass(m1);
migraphx::program p2; migraphx::module m2;
{ {
auto* mm2 = p2.get_main_module(); auto one = m2.add_literal(1);
auto one = mm2->add_literal(1); auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, one);
auto sum1 = mm2->add_instruction(migraphx::make_op("add"), one, one); auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum1, sum1);
auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum1, sum1); m2.add_instruction(pass_op{}, sum3);
mm2->add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
TEST_CASE(cse_test4) TEST_CASE(cse_test4)
{ {
migraphx::program p1; migraphx::module m1;
{ {
auto* mm1 = p1.get_main_module(); auto one = m1.add_literal(1);
auto one = mm1->add_literal(1); auto two = m1.add_literal(1);
auto two = mm1->add_literal(1); auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, two);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), one, two); auto sum2 = m1.add_instruction(migraphx::make_op("add"), two, one);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), two, one); auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, one);
auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, one); auto sum4 = m1.add_instruction(migraphx::make_op("add"), sum2, two);
auto sum4 = mm1->add_instruction(migraphx::make_op("add"), sum2, two); auto sum5 = m1.add_instruction(migraphx::make_op("add"), sum4, sum3);
auto sum5 = mm1->add_instruction(migraphx::make_op("add"), sum4, sum3); m1.add_instruction(pass_op{}, sum5);
mm1->add_instruction(pass_op{}, sum5);
} }
run_pass(p1); run_pass(m1);
migraphx::program p2; migraphx::module m2;
{ {
auto* mm2 = p2.get_main_module(); auto one = m2.add_literal(1);
auto one = mm2->add_literal(1); auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, one);
auto sum1 = mm2->add_instruction(migraphx::make_op("add"), one, one); auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum1, one);
auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum1, one); auto sum5 = m2.add_instruction(migraphx::make_op("add"), sum3, sum3);
auto sum5 = mm2->add_instruction(migraphx::make_op("add"), sum3, sum3); m2.add_instruction(pass_op{}, sum5);
mm2->add_instruction(pass_op{}, sum5);
} }
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
TEST_CASE(cse_test_literal) TEST_CASE(cse_test_literal)
{ {
migraphx::program p1; migraphx::module m1;
{ {
auto* mm1 = p1.get_main_module(); auto six1 = m1.add_literal(6);
auto six1 = mm1->add_literal(6); auto zero1 = m1.add_literal(0);
auto zero1 = mm1->add_literal(0); auto six2 = m1.add_literal(6);
auto six2 = mm1->add_literal(6); auto zero2 = m1.add_literal(0);
auto zero2 = mm1->add_literal(0); auto six3 = m1.add_literal(6);
auto six3 = mm1->add_literal(6); auto zero3 = m1.add_literal(0);
auto zero3 = mm1->add_literal(0);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), six1, zero1); auto sum1 = m1.add_instruction(migraphx::make_op("add"), six1, zero1);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), six2, zero2); auto sum2 = m1.add_instruction(migraphx::make_op("add"), six2, zero2);
auto sum3 = mm1->add_instruction(migraphx::make_op("add"), six3, zero3); auto sum3 = m1.add_instruction(migraphx::make_op("add"), six3, zero3);
auto sum4 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2); auto sum4 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2);
auto sum5 = mm1->add_instruction(migraphx::make_op("add"), sum3, sum4); auto sum5 = m1.add_instruction(migraphx::make_op("add"), sum3, sum4);
mm1->add_instruction(pass_op{}, sum5); m1.add_instruction(pass_op{}, sum5);
} }
run_pass(p1); run_pass(m1);
migraphx::program p2; migraphx::module m2;
{ {
auto* mm2 = p2.get_main_module(); auto six = m2.add_literal(6);
auto six = mm2->add_literal(6); auto zero = m2.add_literal(0);
auto zero = mm2->add_literal(0); auto sum1 = m2.add_instruction(migraphx::make_op("add"), six, zero);
auto sum1 = mm2->add_instruction(migraphx::make_op("add"), six, zero); auto sum2 = m2.add_instruction(migraphx::make_op("add"), sum1, sum1);
auto sum2 = mm2->add_instruction(migraphx::make_op("add"), sum1, sum1); auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum1, sum2);
auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum1, sum2); m2.add_instruction(pass_op{}, sum3);
mm2->add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
This diff is collapsed.
...@@ -6,138 +6,128 @@ ...@@ -6,138 +6,128 @@
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p) void run_pass(migraphx::module& m)
{ {
migraphx::run_passes(*p.get_main_module(), migraphx::run_passes(m, {migraphx::eliminate_contiguous{}, migraphx::dead_code_elimination{}});
{migraphx::eliminate_contiguous{}, migraphx::dead_code_elimination{}});
} }
TEST_CASE(standard_op) TEST_CASE(standard_op)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto l = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2}}); auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t); m.add_instruction(pass_standard_op{}, c);
mm->add_instruction(pass_standard_op{}, c); auto count = std::distance(m.begin(), m.end());
auto count = std::distance(mm->begin(), mm->end()); run_pass(m);
run_pass(p); EXPECT(std::distance(m.begin(), m.end()) == count);
EXPECT(std::distance(mm->begin(), mm->end()) == count);
} }
TEST_CASE(standard_op_const) TEST_CASE(standard_op_const)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto l = m.add_literal(get_2x2());
auto l = mm->add_literal(get_2x2()); auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t); m.add_instruction(pass_standard_op{}, c);
mm->add_instruction(pass_standard_op{}, c); run_pass(m);
run_pass(p); EXPECT(std::distance(m.begin(), m.end()) == 2);
EXPECT(std::distance(mm->begin(), mm->end()) == 2);
} }
TEST_CASE(non_standard_op) TEST_CASE(non_standard_op)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto l = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2}}); auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t); m.add_instruction(pass_op{}, c);
mm->add_instruction(pass_op{}, c); auto count = std::distance(m.begin(), m.end());
auto count = std::distance(mm->begin(), mm->end()); run_pass(m);
run_pass(p); EXPECT(std::distance(m.begin(), m.end()) == count);
EXPECT(std::distance(mm->begin(), mm->end()) == count);
} }
TEST_CASE(non_standard_op_const) TEST_CASE(non_standard_op_const)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto l = m.add_literal(get_2x2());
auto l = mm->add_literal(get_2x2()); auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t); m.add_instruction(pass_op{}, c);
mm->add_instruction(pass_op{}, c); run_pass(m);
run_pass(p); EXPECT(std::distance(m.begin(), m.end()) == 2);
EXPECT(std::distance(mm->begin(), mm->end()) == 2);
} }
TEST_CASE(transpose_gemm) TEST_CASE(transpose_gem)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto l = m.add_literal(get_2x2());
auto l = mm->add_literal(get_2x2()); auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t); auto ic = m.add_instruction(migraphx::make_op("identity"), c);
auto ic = mm->add_instruction(migraphx::make_op("identity"), c); m.add_instruction(migraphx::make_op("dot"), ic, l);
mm->add_instruction(migraphx::make_op("dot"), ic, l); auto count = std::distance(m.begin(), m.end());
auto count = std::distance(mm->begin(), mm->end()); run_pass(m);
run_pass(p); EXPECT(std::distance(m.begin(), m.end()) == (count - 1));
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
} }
TEST_CASE(transpose_standard_op) TEST_CASE(transpose_standard_op)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto l = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2}}); auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t); auto sn = m.add_instruction(migraphx::make_op("sin"), c);
auto sn = mm->add_instruction(migraphx::make_op("sin"), c); m.add_instruction(pass_standard_op{}, sn);
mm->add_instruction(pass_standard_op{}, sn); auto count = std::distance(m.begin(), m.end());
auto count = std::distance(mm->begin(), mm->end()); run_pass(m);
run_pass(p); EXPECT(std::distance(m.begin(), m.end()) == count);
EXPECT(std::distance(mm->begin(), mm->end()) == count);
} }
TEST_CASE(transpose_standard_op_const) TEST_CASE(transpose_standard_op_const)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto l = m.add_literal(get_2x2());
auto l = mm->add_literal(get_2x2()); auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t); auto sn = m.add_instruction(migraphx::make_op("sin"), c);
auto sn = mm->add_instruction(migraphx::make_op("sin"), c); m.add_instruction(pass_standard_op{}, sn);
mm->add_instruction(pass_standard_op{}, sn); run_pass(m);
run_pass(p); EXPECT(std::distance(m.begin(), m.end()) == 3);
EXPECT(std::distance(mm->begin(), mm->end()) == 3);
} }
TEST_CASE(no_packed_unary_op) TEST_CASE(no_packed_unary_op)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto l = m.add_literal(get_2x2());
auto l = mm->add_literal(get_2x2()); auto t = m.add_instruction(
auto t = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), l); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), l);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto sn = mm->add_instruction(migraphx::make_op("sin"), c); auto sn = m.add_instruction(migraphx::make_op("sin"), c);
mm->add_instruction(pass_standard_op{}, sn); m.add_instruction(pass_standard_op{}, sn);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(m.begin(), m.end());
run_pass(p); run_pass(m);
EXPECT(std::distance(mm->begin(), mm->end()) == count - 1); EXPECT(std::distance(m.begin(), m.end()) == count - 1);
} }
TEST_CASE(non_standard_return_input) TEST_CASE(non_standard_return_input)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto l = m.add_literal(get_2x2());
auto l = mm->add_literal(get_2x2()); auto tl = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto tl = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto c = m.add_instruction(migraphx::make_op("contiguous"), tl);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), tl); m.add_return({c});
mm->add_return({c}); auto count = std::distance(m.begin(), m.end());
auto count = std::distance(mm->begin(), mm->end()); run_pass(m);
run_pass(p); EXPECT(std::distance(m.begin(), m.end()) == count);
EXPECT(std::distance(mm->begin(), mm->end()) == count);
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -8,105 +8,98 @@ ...@@ -8,105 +8,98 @@
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p) void run_pass(migraphx::module& m)
{ {
migraphx::run_passes(*p.get_main_module(), migraphx::run_passes(m, {migraphx::eliminate_pad{}, migraphx::dead_code_elimination{}});
{migraphx::eliminate_pad{}, migraphx::dead_code_elimination{}});
} }
migraphx::instruction_ref migraphx::instruction_ref
create_im2col(migraphx::instruction_ref& l_img, size_t channels, migraphx::program& p) create_im2col(migraphx::instruction_ref& l_img, size_t channels, migraphx::module& m)
{ {
size_t f[2] = {1, 1}; size_t f[2] = {1, 1};
std::vector<int32_t> weights(channels * f[0] * f[1]); std::vector<int32_t> weights(channels * f[0] * f[1]);
auto* mm = p.get_main_module();
migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}}; migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_weights = mm->add_literal(migraphx::literal{s_weights, weights}); auto l_weights = m.add_literal(migraphx::literal{s_weights, weights});
return mm->add_instruction(migraphx::make_op("im2col"), l_img, l_weights); return m.add_instruction(migraphx::make_op("im2col"), l_img, l_weights);
} }
migraphx::instruction_ref migraphx::instruction_ref
create_conv(migraphx::instruction_ref& l_img, create_conv(migraphx::instruction_ref& l_img,
size_t channels, size_t channels,
migraphx::program& p, migraphx::module& m,
migraphx::op::padding_mode_t padding_mode = migraphx::op::padding_mode_t::default_) migraphx::op::padding_mode_t padding_mode = migraphx::op::padding_mode_t::default_)
{ {
migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}}; migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}};
std::vector<int32_t> weights(4 * channels * 3 * 3); std::vector<int32_t> weights(4 * channels * 3 * 3);
auto* mm = p.get_main_module(); auto l_weights = m.add_literal(migraphx::literal{s_weights, weights});
auto l_weights = mm->add_literal(migraphx::literal{s_weights, weights});
migraphx::op::convolution op; migraphx::op::convolution op;
op.padding_mode = padding_mode; op.padding_mode = padding_mode;
return mm->add_instruction(op, l_img, l_weights); return m.add_instruction(op, l_img, l_weights);
} }
TEST_CASE(rewrite_pad) TEST_CASE(rewrite_pad)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module();
size_t img_dim[2] = {2, 2}; size_t img_dim[2] = {2, 2};
size_t channels = 1; size_t channels = 1;
std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]); std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]);
std::iota(input.begin(), input.end(), 0); std::iota(input.begin(), input.end(), 0);
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}}; migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = mm->add_literal(migraphx::literal{s_img, input}); auto l_img = m.add_literal(migraphx::literal{s_img, input});
auto padded_img = auto padded_img =
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 1, 1, 0, 0, 1, 1}}}), l_img); m.add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 1, 1, 0, 0, 1, 1}}}), l_img);
auto l0 = create_im2col(padded_img, channels, p); auto l0 = create_im2col(padded_img, channels, m);
auto l1 = create_conv(padded_img, channels, p); auto l1 = create_conv(padded_img, channels, m);
auto l2 = mm->add_instruction(migraphx::make_op("pooling", {{"mode", "max"}}), padded_img); auto l2 = m.add_instruction(migraphx::make_op("pooling", {{"mode", "max"}}), padded_img);
mm->add_instruction(migraphx::make_op("identity"), l0, l1, l2); m.add_instruction(migraphx::make_op("identity"), l0, l1, l2);
auto s0 = l0->get_shape(); auto s0 = l0->get_shape();
auto s1 = l1->get_shape(); auto s1 = l1->get_shape();
auto s2 = l2->get_shape(); auto s2 = l2->get_shape();
run_pass(p); run_pass(m);
EXPECT(l0->get_shape() == s0); EXPECT(l0->get_shape() == s0);
EXPECT(l1->get_shape() == s1); EXPECT(l1->get_shape() == s1);
EXPECT(l2->get_shape() == s2); EXPECT(l2->get_shape() == s2);
auto op0 = l0->get_operator().to_value(); auto op0 = l0->get_operator().to_value();
auto op1 = l1->get_operator().to_value(); auto om1 = l1->get_operator().to_value();
auto op2 = l2->get_operator().to_value(); auto om2 = l2->get_operator().to_value();
EXPECT(op0["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1}); EXPECT(op0["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1});
EXPECT(op1["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1}); EXPECT(om1["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1});
EXPECT(op2["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1}); EXPECT(om2["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1});
EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) { EXPECT(std::none_of(
return ins.name() == "pad"; m.begin(), m.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}));
} }
TEST_CASE(rewrite_pad_im2col_asymmetric) TEST_CASE(rewrite_pad_im2col_asymetric)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module();
size_t img_dim[2] = {2, 2}; size_t img_dim[2] = {2, 2};
size_t channels = 1; size_t channels = 1;
std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]); std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]);
std::iota(input.begin(), input.end(), 0); std::iota(input.begin(), input.end(), 0);
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}}; migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = mm->add_literal(migraphx::literal{s_img, input}); auto l_img = m.add_literal(migraphx::literal{s_img, input});
auto padded_img = auto padded_img =
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 2, 2}}}), l_img); m.add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 2, 2}}}), l_img);
auto l0 = create_im2col(padded_img, channels, p); auto l0 = create_im2col(padded_img, channels, m);
auto s0 = l0->get_shape(); auto s0 = l0->get_shape();
run_pass(p); run_pass(m);
EXPECT(l0->get_shape() == s0); EXPECT(l0->get_shape() == s0);
auto op0 = l0->get_operator().to_value(); auto op0 = l0->get_operator().to_value();
EXPECT(op0["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{0, 0}); EXPECT(op0["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{0, 0});
run_pass(p); run_pass(m);
EXPECT(std::any_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) { EXPECT(std::any_of(
return ins.name() == "pad"; m.begin(), m.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
This diff is collapsed.
...@@ -38,130 +38,125 @@ struct normalize_test_op ...@@ -38,130 +38,125 @@ struct normalize_test_op
} }
}; };
void run_pass(migraphx::program& p) void run_pass(migraphx::module& m)
{ {
migraphx::run_passes(*p.get_main_module(), migraphx::run_passes(m, {migraphx::normalize_ops{}, migraphx::dead_code_elimination{}});
{migraphx::normalize_ops{}, migraphx::dead_code_elimination{}});
} }
migraphx::program create_gather(int64_t axis) migraphx::module create_gather(int64_t axis)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {2, 3, 4}}; migraphx::shape sd{migraphx::shape::float_type, {2, 3, 4}};
migraphx::shape si{migraphx::shape::int64_type, {2, 3}}; migraphx::shape si{migraphx::shape::int64_type, {2, 3}};
auto di = mm->add_parameter("data", sd); auto di = m.add_parameter("data", sd);
auto ii = mm->add_parameter("ind", si); auto ii = m.add_parameter("ind", si);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), di, ii); auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", axis}}), di, ii);
mm->add_return({r}); m.add_return({r});
return p; return m;
} }
TEST_CASE(gather_test) TEST_CASE(gather_test)
{ {
auto p1 = create_gather(-3); auto m1 = create_gather(-3);
auto p2 = create_gather(0); auto m2 = create_gather(0);
run_pass(p1); run_pass(m1);
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
TEST_CASE(gather_test_1) TEST_CASE(gather_test_1)
{ {
auto p1 = create_gather(1); auto m1 = create_gather(1);
auto p2 = create_gather(1); auto m2 = create_gather(1);
run_pass(p1); run_pass(m1);
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
migraphx::program create_reduce_mean(const std::vector<int64_t>& axes) migraphx::module create_reduce_mean(const std::vector<int64_t>& axes)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}};
auto si = mm->add_parameter("data", s); auto si = m.add_parameter("data", s);
auto r = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", axes}}), si); auto r = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", axes}}), si);
mm->add_return({r}); m.add_return({r});
return p; return m;
} }
TEST_CASE(reduce_mean_test) TEST_CASE(reduce_mean_test)
{ {
migraphx::program p1 = create_reduce_mean({0, 1, -1}); migraphx::module m1 = create_reduce_mean({0, 1, -1});
migraphx::program p2 = create_reduce_mean({0, 1, 3}); migraphx::module m2 = create_reduce_mean({0, 1, 3});
run_pass(p1); run_pass(m1);
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
TEST_CASE(reduce_mean_test_1) TEST_CASE(reduce_mean_test_1)
{ {
migraphx::program p1 = create_reduce_mean({0, 1, 2}); migraphx::module m1 = create_reduce_mean({0, 1, 2});
migraphx::program p2 = create_reduce_mean({0, 1, 2}); migraphx::module m2 = create_reduce_mean({0, 1, 2});
run_pass(p1); run_pass(m1);
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
migraphx::program create_slice(const std::vector<int64_t>& axes, migraphx::module create_slice(const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts, const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends) const std::vector<int64_t>& ends)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}};
auto si = mm->add_parameter("data", s); auto si = m.add_parameter("data", s);
auto r = mm->add_instruction( auto r = m.add_instruction(
migraphx::make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), si); migraphx::make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), si);
mm->add_return({r}); m.add_return({r});
return p; return m;
} }
TEST_CASE(slice_test) TEST_CASE(slice_test)
{ {
migraphx::program p1 = create_slice({0, 1, -1}, {-5, 1, -3}, {2, 2, 8}); migraphx::module m1 = create_slice({0, 1, -1}, {-5, 1, -3}, {2, 2, 8});
migraphx::program p2 = create_slice({0, 1, 3}, {0, 1, 2}, {2, 2, 5}); migraphx::module m2 = create_slice({0, 1, 3}, {0, 1, 2}, {2, 2, 5});
run_pass(p1); run_pass(m1);
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
TEST_CASE(slice_test_1) TEST_CASE(slice_test_1)
{ {
migraphx::program p1 = create_slice({0, 1, 3}, {0, 1, -3}, {1, 2, 5}); migraphx::module m1 = create_slice({0, 1, 3}, {0, 1, -3}, {1, 2, 5});
migraphx::program p2 = create_slice({0, 1, 3}, {0, 1, 2}, {1, 2, 5}); migraphx::module m2 = create_slice({0, 1, 3}, {0, 1, 2}, {1, 2, 5});
run_pass(p1); run_pass(m1);
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
migraphx::program create_test_op(const std::vector<int64_t>& axes) migraphx::module create_test_op(const std::vector<int64_t>& axes)
{ {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {2, 3, 4}}; migraphx::shape sd{migraphx::shape::float_type, {2, 3, 4}};
auto di = mm->add_parameter("data", sd); auto di = m.add_parameter("data", sd);
auto r = mm->add_instruction(normalize_test_op{axes}, di); auto r = m.add_instruction(normalize_test_op{axes}, di);
mm->add_return({r}); m.add_return({r});
return p; return m;
} }
TEST_CASE(test_op) TEST_CASE(test_op)
{ {
std::vector<int64_t> axes1 = {-4, 5}; std::vector<int64_t> axes1 = {-4, 5};
auto p1 = create_test_op(axes1); auto m1 = create_test_op(axes1);
std::vector<int64_t> axes2 = {1, 2}; std::vector<int64_t> axes2 = {1, 2};
auto p2 = create_test_op(axes2); auto m2 = create_test_op(axes2);
run_pass(p1); run_pass(m1);
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -6,122 +6,109 @@ ...@@ -6,122 +6,109 @@
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p) void run_pass(migraphx::module& m)
{ {
migraphx::run_passes(*p.get_main_module(), migraphx::run_passes(m, {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}});
{migraphx::propagate_constant{}, migraphx::dead_code_elimination{}});
} }
TEST_CASE(const_add) TEST_CASE(const_add)
{ {
migraphx::program p1; migraphx::module m1;
auto* mm1 = p1.get_main_module(); auto one = m1.add_literal(1);
auto one = mm1->add_literal(1); auto two = m1.add_literal(2);
auto two = mm1->add_literal(2); auto sum = m1.add_instruction(migraphx::make_op("add"), one, two);
auto sum = mm1->add_instruction(migraphx::make_op("add"), one, two); m1.add_instruction(pass_op{}, sum);
mm1->add_instruction(pass_op{}, sum); run_pass(m1);
run_pass(p1);
migraphx::program p2; migraphx::module m2;
auto* mm2 = p2.get_main_module(); auto total = m2.add_literal(3);
auto total = mm2->add_literal(3); m2.add_instruction(pass_op{}, total);
mm2->add_instruction(pass_op{}, total); EXPECT(m1 == m2);
EXPECT(p1 == p2);
} }
TEST_CASE(const_add_parameter) TEST_CASE(const_add_parameter)
{ {
migraphx::program p1; migraphx::module m1;
auto* mm1 = p1.get_main_module(); auto one = m1.add_parameter("one", {migraphx::shape::int32_type, {1}});
auto one = mm1->add_parameter("one", {migraphx::shape::int32_type, {1}}); auto two = m1.add_literal(2);
auto two = mm1->add_literal(2); auto sum = m1.add_instruction(migraphx::make_op("add"), one, two);
auto sum = mm1->add_instruction(migraphx::make_op("add"), one, two); m1.add_instruction(pass_op{}, sum);
mm1->add_instruction(pass_op{}, sum); run_pass(m1);
run_pass(p1);
migraphx::program p2; migraphx::module m2;
auto* mm2 = p2.get_main_module(); auto total = m2.add_literal(3);
auto total = mm2->add_literal(3); m2.add_instruction(pass_op{}, total);
mm2->add_instruction(pass_op{}, total); EXPECT(m1 != m2);
EXPECT(p1 != p2);
} }
TEST_CASE(const_multiadd) TEST_CASE(const_multiadd)
{ {
migraphx::program p1; migraphx::module m1;
auto* mm1 = p1.get_main_module(); auto one = m1.add_literal(1);
auto one = mm1->add_literal(1); auto two = m1.add_literal(2);
auto two = mm1->add_literal(2); auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, two);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), one, two); auto sum2 = m1.add_instruction(migraphx::make_op("add"), sum1, two);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), sum1, two); m1.add_instruction(pass_op{}, sum2);
mm1->add_instruction(pass_op{}, sum2); run_pass(m1);
run_pass(p1);
migraphx::program p2; migraphx::module m2;
auto* mm2 = p2.get_main_module(); auto total = m2.add_literal(5);
auto total = mm2->add_literal(5); m2.add_instruction(pass_op{}, total);
mm2->add_instruction(pass_op{}, total); EXPECT(m1 == m2);
EXPECT(p1 == p2);
} }
TEST_CASE(const_add_mul) TEST_CASE(const_add_mul)
{ {
migraphx::program p1; migraphx::module m1;
auto* mm1 = p1.get_main_module(); auto one = m1.add_literal(1);
auto one = mm1->add_literal(1); auto two = m1.add_literal(2);
auto two = mm1->add_literal(2); auto mul = m1.add_instruction(migraphx::make_op("mul"), two, two);
auto mul = mm1->add_instruction(migraphx::make_op("mul"), two, two); auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, mul);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), one, mul); auto sum2 = m1.add_instruction(migraphx::make_op("add"), sum1, two);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), sum1, two); m1.add_instruction(pass_op{}, sum2);
mm1->add_instruction(pass_op{}, sum2); run_pass(m1);
run_pass(p1);
migraphx::program p2; migraphx::module m2;
auto* mm2 = p2.get_main_module(); auto total = m2.add_literal(7);
auto total = mm2->add_literal(7); m2.add_instruction(pass_op{}, total);
mm2->add_instruction(pass_op{}, total); EXPECT(m1 == m2);
EXPECT(p1 == p2);
} }
TEST_CASE(const_add_scalar) TEST_CASE(const_add_scalar)
{ {
migraphx::program p1; migraphx::module m1;
auto* mm1 = p1.get_main_module(); auto one = m1.add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}),
auto one = mm1->add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}), m1.add_literal(1));
mm1->add_literal(1)); auto two = m1.add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}),
auto two = mm1->add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}), m1.add_literal(2));
mm1->add_literal(2)); auto sum = m1.add_instruction(migraphx::make_op("add"), one, two);
auto sum = mm1->add_instruction(migraphx::make_op("add"), one, two); m1.add_instruction(pass_op{}, sum);
mm1->add_instruction(pass_op{}, sum); run_pass(m1);
run_pass(p1);
migraphx::program p2; migraphx::module m2;
auto* mm2 = p2.get_main_module();
auto total = auto total =
mm2->add_literal(migraphx::literal{{migraphx::shape::int32_type, {2, 2}}, {3, 3, 3, 3}}); m2.add_literal(migraphx::literal{{migraphx::shape::int32_type, {2, 2}}, {3, 3, 3, 3}});
mm2->add_instruction(pass_op{}, total); m2.add_instruction(pass_op{}, total);
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
TEST_CASE(const_scalar) TEST_CASE(const_scalar)
{ {
migraphx::program p1; migraphx::module m1;
{ {
auto* mm1 = p1.get_main_module(); auto one = m1.add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}),
auto one = mm1->add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}), m1.add_literal(1));
mm1->add_literal(1)); m1.add_instruction(pass_op{}, one);
mm1->add_instruction(pass_op{}, one);
} }
run_pass(p1); run_pass(m1);
migraphx::program p2; migraphx::module m2;
{ {
auto* mm2 = p2.get_main_module(); auto one = m2.add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}),
auto one = mm2->add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}), m2.add_literal(1));
mm2->add_literal(1)); m2.add_instruction(pass_op{}, one);
mm2->add_instruction(pass_op{}, one);
} }
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -11,49 +11,46 @@ ...@@ -11,49 +11,46 @@
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
bool is_pooling(migraphx::instruction& ins) { return ins.name() == "pooling"; } bool is_pooling(migraphx::instruction& ins) { return ins.name() == "pooling"; }
static void opt_pooling(migraphx::program& prog) static void opt_pooling(migraphx::module& m)
{ {
auto* mm = prog.get_main_module();
migraphx::rewrite_pooling rp; migraphx::rewrite_pooling rp;
migraphx::dead_code_elimination dce; migraphx::dead_code_elimination dce;
rp.apply(*mm); rp.apply(m);
dce.apply(*mm); dce.apply(m);
} }
TEST_CASE(rewrite_pooling_test) TEST_CASE(rewrite_pooling_test)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}}; migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
auto pooling_program = [&](const std::string& mode) { auto pooling_program = [&](const std::string& mode) {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto input = m.add_parameter("x", s);
auto input = mm->add_parameter("x", s); auto ret = m.add_instruction(migraphx::make_op("pooling",
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", mode}, {{"mode", mode},
{"padding", {0, 0, 0}}, {"padding", {0, 0, 0}},
{"stride", {1, 1, 1}}, {"stride", {1, 1, 1}},
{"lengths", {3, 4, 5}}}), {"lengths", {3, 4, 5}}}),
input); input);
mm->add_return({ret}); m.add_return({ret});
return p; return m;
}; };
auto opt_program = [&](const migraphx::operation& reduce_op) { auto opt_program = [&](const migraphx::operation& reduce_op) {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto input = m.add_parameter("x", s);
auto input = mm->add_parameter("x", s); auto rsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), input);
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), input); auto rdm = m.add_instruction(reduce_op, rsp);
auto rdm = mm->add_instruction(reduce_op, rsp);
auto ret = auto ret =
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 1, 1, 1}}}), rdm); m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 1, 1, 1}}}), rdm);
mm->add_return({ret}); m.add_return({ret});
return p; return m;
}; };
auto test_rewrite = [&](const std::string& mode, const migraphx::operation& op) { auto test_rewrite = [&](const std::string& mode, const migraphx::operation& op) {
migraphx::program p1 = pooling_program(mode); migraphx::module m1 = pooling_program(mode);
migraphx::program p2 = opt_program(op); migraphx::module m2 = opt_program(op);
opt_pooling(p1); opt_pooling(m1);
EXPECT(p1 == p2); EXPECT(m1 == m2);
}; };
test_rewrite("average", migraphx::make_op("reduce_mean", {{"axes", {1}}})); test_rewrite("average", migraphx::make_op("reduce_mean", {{"axes", {1}}}));
...@@ -64,75 +61,72 @@ TEST_CASE(rewrite_avepooling_na1_test) ...@@ -64,75 +61,72 @@ TEST_CASE(rewrite_avepooling_na1_test)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}}; migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
auto pooling_program = [&]() { auto pooling_program = [&]() {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto input = m.add_parameter("x", s);
auto input = mm->add_parameter("x", s); auto ret = m.add_instruction(migraphx::make_op("pooling",
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", "average"},
{"padding", {0, 1, 0}}, {"padding", {0, 1, 0}},
{"stride", {1, 1, 1}}, {"stride", {1, 1, 1}},
{"lengths", {3, 4, 5}}}), {"lengths", {3, 4, 5}}}),
input); input);
mm->add_return({ret}); m.add_return({ret});
return p; return m;
}; };
migraphx::program p1 = pooling_program(); migraphx::module m1 = pooling_program();
migraphx::program p2 = p1; migraphx::module m2 = m1;
opt_pooling(p1); opt_pooling(m1);
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
TEST_CASE(rewrite_avepooling_na2_test) TEST_CASE(rewrite_avepooling_na2_test)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}}; migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
auto pooling_program = [&]() { auto pooling_program = [&]() {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto input = m.add_parameter("x", s);
auto input = mm->add_parameter("x", s); auto ret = m.add_instruction(migraphx::make_op("pooling",
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", "average"},
{"padding", {0, 0, 0}}, {"padding", {0, 0, 0}},
{"stride", {1, 2, 1}}, {"stride", {1, 2, 1}},
{"lengths", {3, 4, 5}}}), {"lengths", {3, 4, 5}}}),
input); input);
mm->add_return({ret}); m.add_return({ret});
return p; return m;
}; };
migraphx::program p1 = pooling_program(); migraphx::module m1 = pooling_program();
migraphx::program p2 = p1; migraphx::module m2 = m1;
opt_pooling(p1); opt_pooling(m1);
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
TEST_CASE(rewrite_avepooling_na3_test) TEST_CASE(rewrite_avepooling_na3_test)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}}; migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
auto pooling_program = [&]() { auto pooling_program = [&]() {
migraphx::program p; migraphx::module m;
auto* mm = p.get_main_module(); auto input = m.add_parameter("x", s);
auto input = mm->add_parameter("x", s); auto ret = m.add_instruction(migraphx::make_op("pooling",
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "max"}, {{"mode", "max"},
{"padding", {0, 0, 0}}, {"padding", {0, 0, 0}},
{"stride", {1, 1, 1}}, {"stride", {1, 1, 1}},
{"lengths", {3, 3, 5}}}), {"lengths", {3, 3, 5}}}),
input); input);
mm->add_return({ret}); m.add_return({ret});
return p; return m;
}; };
migraphx::program p1 = pooling_program(); migraphx::module m1 = pooling_program();
migraphx::program p2 = p1; migraphx::module m2 = m1;
opt_pooling(p1); opt_pooling(m1);
EXPECT(p1 == p2); EXPECT(m1 == m2);
} }
TEST_CASE(literal_rewrite_pooling_test) TEST_CASE(literal_rewrite_pooling_test)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment