Unverified Commit 8d21fdc9 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Refactor to use make_op almost everywhere (#696)

* Load op when serializing

* Formatting

* Add missing clip field

* Use make_op almost everywhere

* Formatting

* More make ops for rnns

* Get rid of spaces

* Formatting

* Remove operators headers

* Formatting

* Remove unused op headers

* Increase line threshold
parent b5633c27
......@@ -2,14 +2,12 @@
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/verify.hpp>
bool is_pooling(migraphx::instruction& ins) { return ins.name() == "pooling"; }
......@@ -29,7 +27,11 @@ TEST_CASE(rewrite_pooling_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", s);
auto ret = mm->add_instruction(migraphx::op::pooling{mode, {0, 0, 0}, {1, 1, 1}, {3, 4, 5}},
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", mode},
{"padding", {0, 0, 0}},
{"stride", {1, 1, 1}},
{"lengths", {3, 4, 5}}}),
input);
mm->add_return({ret});
return p;
......@@ -39,9 +41,10 @@ TEST_CASE(rewrite_pooling_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", s);
auto rsp = mm->add_instruction(migraphx::op::reshape{{4, -1}}, input);
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), input);
auto rdm = mm->add_instruction(reduce_op, rsp);
auto ret = mm->add_instruction(migraphx::op::reshape{{2, 2, 1, 1, 1}}, rdm);
auto ret =
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 1, 1, 1}}}), rdm);
mm->add_return({ret});
return p;
};
......@@ -53,8 +56,8 @@ TEST_CASE(rewrite_pooling_test)
EXPECT(p1 == p2);
};
test_rewrite("average", migraphx::op::reduce_mean{{1}});
test_rewrite("max", migraphx::op::reduce_max{{1}});
test_rewrite("average", migraphx::make_op("reduce_mean", {{"axes", {1}}}));
test_rewrite("max", migraphx::make_op("reduce_max", {{"axes", {1}}}));
}
TEST_CASE(rewrite_avepooling_na1_test)
......@@ -65,8 +68,12 @@ TEST_CASE(rewrite_avepooling_na1_test)
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", s);
auto ret = mm->add_instruction(
migraphx::op::pooling{"average", {0, 1, 0}, {1, 1, 1}, {3, 4, 5}}, input);
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{"padding", {0, 1, 0}},
{"stride", {1, 1, 1}},
{"lengths", {3, 4, 5}}}),
input);
mm->add_return({ret});
return p;
};
......@@ -86,8 +93,12 @@ TEST_CASE(rewrite_avepooling_na2_test)
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", s);
auto ret = mm->add_instruction(
migraphx::op::pooling{"average", {0, 0, 0}, {1, 2, 1}, {3, 4, 5}}, input);
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{"padding", {0, 0, 0}},
{"stride", {1, 2, 1}},
{"lengths", {3, 4, 5}}}),
input);
mm->add_return({ret});
return p;
};
......@@ -107,8 +118,12 @@ TEST_CASE(rewrite_avepooling_na3_test)
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", s);
auto ret = mm->add_instruction(
migraphx::op::pooling{"max", {0, 0, 0}, {1, 1, 1}, {3, 3, 5}}, input);
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "max"},
{"padding", {0, 0, 0}},
{"stride", {1, 1, 1}},
{"lengths", {3, 3, 5}}}),
input);
mm->add_return({ret});
return p;
};
......@@ -131,7 +146,11 @@ TEST_CASE(literal_rewrite_pooling_test)
auto* mm = p.get_main_module();
auto input = mm->add_literal(migraphx::literal(s, data));
auto ret = mm->add_instruction(migraphx::op::pooling{mode, {0, 0, 0}, {1, 1, 1}, {3, 4, 5}},
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", mode},
{"padding", {0, 0, 0}},
{"stride", {1, 1, 1}},
{"lengths", {3, 4, 5}}}),
input);
mm->add_return({ret});
return p;
......@@ -141,9 +160,10 @@ TEST_CASE(literal_rewrite_pooling_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_literal(migraphx::literal(s, data));
auto rsp = mm->add_instruction(migraphx::op::reshape{{4, -1}}, input);
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), input);
auto rdm = mm->add_instruction(op, rsp);
auto ret = mm->add_instruction(migraphx::op::reshape{{2, 2, 1, 1, 1}}, rdm);
auto ret =
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 1, 1, 1}}}), rdm);
mm->add_return({ret});
return p;
......@@ -160,8 +180,8 @@ TEST_CASE(literal_rewrite_pooling_test)
result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
};
test_rewrite_pooling("max", migraphx::op::reduce_max{{1}});
test_rewrite_pooling("average", migraphx::op::reduce_mean{{1}});
test_rewrite_pooling("max", migraphx::make_op("reduce_max", {{"axes", {1}}}));
test_rewrite_pooling("average", migraphx::make_op("reduce_mean", {{"axes", {1}}}));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/schedule.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
struct unary_op
......@@ -297,8 +298,8 @@ TEST_CASE(zero_record)
auto one = mm->add_literal(1);
auto onep1 = mm->add_instruction(unary_op{}, one);
auto onep2 = mm->add_instruction(unary_op{}, one);
auto onei1 = mm->add_instruction(migraphx::op::identity{}, onep1);
auto onei2 = mm->add_instruction(migraphx::op::identity{}, onep2);
auto onei1 = mm->add_instruction(migraphx::make_op("identity"), onep1);
auto onei2 = mm->add_instruction(migraphx::make_op("identity"), onep2);
auto binary = mm->add_instruction(nary_op{}, onei1, onei2);
t.run_pass(p);
EXPECT(not t.has_stream(one));
......@@ -319,7 +320,7 @@ TEST_CASE(zero_merge1)
auto one = mm->add_literal(1);
auto onep1 = mm->add_instruction(unary_op{}, one);
auto onep2 = mm->add_instruction(unary_op{}, one);
auto binary = mm->add_instruction(migraphx::op::identity{}, onep1, onep2);
auto binary = mm->add_instruction(migraphx::make_op("identity"), onep1, onep2);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
......@@ -339,9 +340,9 @@ TEST_CASE(zero_merge2)
auto one = mm->add_literal(1);
auto onep1 = mm->add_instruction(unary_op{}, one);
auto onep2 = mm->add_instruction(unary_op{}, one);
auto binary = mm->add_instruction(migraphx::op::identity{},
mm->add_instruction(migraphx::op::identity{}, onep1),
mm->add_instruction(migraphx::op::identity{}, onep2));
auto binary = mm->add_instruction(migraphx::make_op("identity"),
mm->add_instruction(migraphx::make_op("identity"), onep1),
mm->add_instruction(migraphx::make_op("identity"), onep2));
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
......@@ -361,7 +362,7 @@ TEST_CASE(zero_merge3)
auto one = mm->add_literal(1);
auto onep1 = mm->add_instruction(unary_op{}, one);
auto onep2 = mm->add_instruction(unary_op{}, one);
auto id = mm->add_instruction(migraphx::op::identity{}, onep1, onep2);
auto id = mm->add_instruction(migraphx::make_op("identity"), onep1, onep2);
auto final = mm->add_instruction(unary_op{}, id);
t.run_pass(p);
EXPECT(not t.has_stream(one));
......@@ -386,9 +387,9 @@ TEST_CASE(zero_merge4)
auto one = mm->add_literal(1);
auto onep1 = mm->add_instruction(unary_op{}, one);
auto onep2 = mm->add_instruction(unary_op{}, one);
auto id = mm->add_instruction(migraphx::op::identity{},
mm->add_instruction(migraphx::op::identity{}, onep1),
mm->add_instruction(migraphx::op::identity{}, onep2));
auto id = mm->add_instruction(migraphx::make_op("identity"),
mm->add_instruction(migraphx::make_op("identity"), onep1),
mm->add_instruction(migraphx::make_op("identity"), onep2));
auto final = mm->add_instruction(unary_op{}, id);
t.run_pass(p);
EXPECT(not t.has_stream(one));
......@@ -811,17 +812,17 @@ TEST_CASE(inception1)
auto i4 = mm->add_literal(2);
auto i7 = mm->add_instruction(nary_op{"i7"}, i1, i4, i3, i2);
auto i8 = mm->add_literal(2);
auto i9 = mm->add_instruction(migraphx::op::identity{}, i8);
auto i9 = mm->add_instruction(migraphx::make_op("identity"), i8);
auto i10 = mm->add_literal(1);
auto i11 = mm->add_instruction(nary_op{"i11"}, i7, i9, i10);
auto i12 = mm->add_literal(2);
auto i13 = mm->add_instruction(migraphx::op::identity{}, i12);
auto i13 = mm->add_instruction(migraphx::make_op("identity"), i12);
auto i14 = mm->add_literal(1);
auto i15 = mm->add_literal(1);
auto i16 = mm->add_literal(2);
auto i17 = mm->add_instruction(nary_op{"i17"}, i11, i16, i15, i13, i14);
auto i18 = mm->add_literal(2);
auto i19 = mm->add_instruction(migraphx::op::identity{}, i18);
auto i19 = mm->add_instruction(migraphx::make_op("identity"), i18);
auto i20 = mm->add_literal(1);
auto i21 = mm->add_literal(1);
auto i22 = mm->add_literal(2);
......@@ -829,13 +830,13 @@ TEST_CASE(inception1)
auto i24 = mm->add_literal(1);
auto i25 = mm->add_instruction(nary_op{"i25"}, i23, i24);
auto i26 = mm->add_literal(2);
auto i27 = mm->add_instruction(migraphx::op::identity{}, i26);
auto i27 = mm->add_instruction(migraphx::make_op("identity"), i26);
auto i28 = mm->add_literal(1);
auto i29 = mm->add_literal(1);
auto i30 = mm->add_literal(2);
auto i31 = mm->add_instruction(nary_op{"i31"}, i25, i30, i29, i27, i28);
auto i32 = mm->add_literal(2);
auto i33 = mm->add_instruction(migraphx::op::identity{}, i32);
auto i33 = mm->add_instruction(migraphx::make_op("identity"), i32);
auto i34 = mm->add_literal(1);
auto i35 = mm->add_literal(1);
auto i36 = mm->add_literal(2);
......@@ -843,53 +844,53 @@ TEST_CASE(inception1)
auto i38 = mm->add_literal(1);
auto i39 = mm->add_instruction(nary_op{"i39"}, i37, i38);
auto i41 = mm->add_literal(2);
auto i42 = mm->add_instruction(migraphx::op::identity{}, i41);
auto i42 = mm->add_instruction(migraphx::make_op("identity"), i41);
auto i43 = mm->add_literal(1);
auto i44 = mm->add_literal(1);
auto i45 = mm->add_literal(2);
auto i48 = mm->add_instruction(nary_op{"i48"}, i39, i45, i44, i42, i43);
auto i49 = mm->add_literal(2);
auto i50 = mm->add_instruction(migraphx::op::identity{}, i49);
auto i50 = mm->add_instruction(migraphx::make_op("identity"), i49);
auto i51 = mm->add_literal(1);
auto i52 = mm->add_literal(1);
auto i53 = mm->add_literal(2);
auto i54 = mm->add_instruction(nary_op{"i54"}, i48, i53, i52, i50, i51);
auto i55 = mm->add_literal(1);
auto i56 = mm->add_instruction(migraphx::op::identity{}, i55);
auto i56 = mm->add_instruction(migraphx::make_op("identity"), i55);
auto i57 = mm->add_literal(2);
auto i58 = mm->add_instruction(migraphx::op::identity{}, i57);
auto i58 = mm->add_instruction(migraphx::make_op("identity"), i57);
auto i59 = mm->add_literal(1);
auto i60 = mm->add_literal(2);
auto i61 = mm->add_instruction(nary_op{"i61"}, i54, i60, i59, i58, i56);
auto i62 = mm->add_literal(2);
auto i63 = mm->add_instruction(migraphx::op::identity{}, i62);
auto i63 = mm->add_instruction(migraphx::make_op("identity"), i62);
auto i64 = mm->add_literal(1);
auto i65 = mm->add_literal(1);
auto i66 = mm->add_literal(2);
auto i69 = mm->add_instruction(nary_op{"i69"}, i39, i66, i65, i63, i64);
auto i70 = mm->add_instruction(migraphx::op::identity{}, i55);
auto i70 = mm->add_instruction(migraphx::make_op("identity"), i55);
auto i71 = mm->add_literal(2);
auto i72 = mm->add_instruction(migraphx::op::identity{}, i71);
auto i72 = mm->add_instruction(migraphx::make_op("identity"), i71);
auto i73 = mm->add_literal(1);
auto i74 = mm->add_literal(2);
auto i75 = mm->add_instruction(nary_op{"i75"}, i69, i74, i73, i72, i70);
auto i77 = mm->add_literal(1);
auto i80 = mm->add_instruction(nary_op{"i80"}, i39, i77);
auto i81 = mm->add_instruction(migraphx::op::identity{}, i55);
auto i81 = mm->add_instruction(migraphx::make_op("identity"), i55);
auto i82 = mm->add_literal(2);
auto i83 = mm->add_instruction(migraphx::op::identity{}, i82);
auto i83 = mm->add_instruction(migraphx::make_op("identity"), i82);
auto i84 = mm->add_literal(1);
auto i85 = mm->add_literal(2);
auto i86 = mm->add_instruction(nary_op{"i86"}, i80, i85, i84, i83, i81);
auto i88 = mm->add_instruction(migraphx::op::identity{}, i55);
auto i88 = mm->add_instruction(migraphx::make_op("identity"), i55);
auto i89 = mm->add_literal(2);
auto i90 = mm->add_instruction(migraphx::op::identity{}, i89);
auto i90 = mm->add_instruction(migraphx::make_op("identity"), i89);
auto i91 = mm->add_literal(1);
auto i92 = mm->add_literal(2);
auto i94 = mm->add_instruction(nary_op{"i94"}, i39, i92, i91, i90, i88);
auto i96 = mm->add_instruction(migraphx::op::identity{}, i55, i94, i75, i61, i86);
auto i96 = mm->add_instruction(migraphx::make_op("identity"), i55, i94, i75, i61, i86);
auto i97 = mm->add_literal(2);
auto i98 = mm->add_instruction(migraphx::op::identity{}, i97);
auto i98 = mm->add_instruction(migraphx::make_op("identity"), i97);
auto i99 = mm->add_literal(3);
auto i100 = mm->add_literal(1);
auto i101 = mm->add_literal(2);
......
#include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/op/add.hpp>
#include "test.hpp"
#include <migraphx/make_op.hpp>
#include <cstdio>
migraphx::program create_program()
......@@ -12,7 +13,7 @@ migraphx::program create_program()
auto x = mm->add_parameter("x", {migraphx::shape::int32_type});
auto two = mm->add_literal(2);
auto add = mm->add_instruction(migraphx::op::add{}, x, two);
auto add = mm->add_instruction(migraphx::make_op("add"), x, two);
mm->add_return({add});
return p;
}
......
......@@ -6,6 +6,8 @@
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_pass(migraphx::program& p)
......@@ -23,9 +25,9 @@ TEST_CASE(simplify_add1)
auto y = mm1->add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = mm1->add_literal(1);
auto two = mm1->add_literal(2);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, one);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, two);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), x, one);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), y, two);
auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
mm1->add_instruction(pass_op{}, sum3);
}
run_pass(p1);
......@@ -37,9 +39,9 @@ TEST_CASE(simplify_add1)
auto y = mm2->add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, two);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, x, y);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum2, sum1);
auto sum1 = mm2->add_instruction(migraphx::make_op("add"), one, two);
auto sum2 = mm2->add_instruction(migraphx::make_op("add"), x, y);
auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum2, sum1);
mm2->add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
......@@ -54,9 +56,9 @@ TEST_CASE(simplify_add2)
auto y = mm1->add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = mm1->add_literal(1);
auto two = mm1->add_literal(2);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, x);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, two, y);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), one, x);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), two, y);
auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
mm1->add_instruction(pass_op{}, sum3);
}
run_pass(p1);
......@@ -68,9 +70,9 @@ TEST_CASE(simplify_add2)
auto y = mm2->add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, two);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, x, y);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum2, sum1);
auto sum1 = mm2->add_instruction(migraphx::make_op("add"), one, two);
auto sum2 = mm2->add_instruction(migraphx::make_op("add"), x, y);
auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum2, sum1);
mm2->add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
......@@ -84,9 +86,9 @@ TEST_CASE(simplify_add3)
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = mm1->add_literal(1);
auto two = mm1->add_literal(2);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, x);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, one, two);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), one, x);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), one, two);
auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
mm1->add_instruction(pass_op{}, sum3);
}
run_pass(p1);
......@@ -97,9 +99,9 @@ TEST_CASE(simplify_add3)
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, two);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, one, sum1);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, x, sum2);
auto sum1 = mm2->add_instruction(migraphx::make_op("add"), one, two);
auto sum2 = mm2->add_instruction(migraphx::make_op("add"), one, sum1);
auto sum3 = mm2->add_instruction(migraphx::make_op("add"), x, sum2);
mm2->add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
......@@ -119,9 +121,9 @@ TEST_CASE(simplify_add_broadcast1)
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal({inner, {2, 2}});
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), y, twob);
auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
mm1->add_instruction(pass_op{}, sum3);
}
run_pass(p1);
......@@ -133,10 +135,10 @@ TEST_CASE(simplify_add_broadcast1)
auto y = mm2->add_parameter("y", outer);
auto one = mm2->add_literal({inner, {1, 1}});
auto two = mm2->add_literal({inner, {2, 2}});
auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, two);
auto sum1 = mm2->add_instruction(migraphx::make_op("add"), one, two);
auto sum1b = mm2->add_instruction(b, sum1);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, x, y);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum2, sum1b);
auto sum2 = mm2->add_instruction(migraphx::make_op("add"), x, y);
auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum2, sum1b);
mm2->add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
......@@ -155,9 +157,9 @@ TEST_CASE(simplify_add_broadcast2)
auto one = mm->add_literal({inner, {1, 1}});
auto oneb = mm->add_instruction(b, one);
auto two = mm->add_literal({outer, {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}});
auto sum1 = mm->add_instruction(migraphx::op::add{}, x, y);
auto sum2 = mm->add_instruction(migraphx::op::add{}, oneb, two);
auto sum3 = mm->add_instruction(migraphx::op::add{}, sum2, sum1);
auto sum1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto sum2 = mm->add_instruction(migraphx::make_op("add"), oneb, two);
auto sum3 = mm->add_instruction(migraphx::make_op("add"), sum2, sum1);
mm->add_instruction(pass_op{}, sum3);
return p;
};
......@@ -179,9 +181,9 @@ void simplify_add4()
auto y = mm1->add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = mm1->add_literal(1);
auto two = mm1->add_literal(2);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, x);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, sum1, y);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum2, two);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), one, x);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), sum1, y);
auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum2, two);
mm1->add_instruction(pass_op{}, sum3);
}
run_pass(p1);
......@@ -193,9 +195,9 @@ void simplify_add4()
auto y = mm2->add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, two);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, x, y);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum2, sum1);
auto sum1 = mm2->add_instruction(migraphx::make_op("add"), one, two);
auto sum2 = mm2->add_instruction(migraphx::make_op("add"), x, y);
auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum2, sum1);
mm2->add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
......@@ -208,10 +210,15 @@ TEST_CASE(simplify_mul_conv1)
auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}});
auto w = mm->add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}}));
auto conv = mm->add_instruction(migraphx::op::convolution{{1, 1}, {2, 2}, {1, 1}}, x, w);
auto a = mm->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}}));
auto b = mm->add_instruction(migraphx::op::broadcast{1, {1, 256, 14, 14}}, a);
auto mul = mm->add_instruction(migraphx::op::mul{}, conv, b);
auto conv = mm->add_instruction(
migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
x,
w);
auto a = mm->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}}));
auto b = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 256, 14, 14}}}), a);
auto mul = mm->add_instruction(migraphx::make_op("mul"), conv, b);
mm->add_instruction(pass_op{}, mul);
EXPECT(conv->outputs().front()->name() == "mul");
run_pass(p);
......@@ -228,13 +235,16 @@ TEST_CASE(simplify_mul_slice_conv1)
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = mm1->add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
auto conv = mm1->add_instruction(migraphx::op::convolution{}, x, w);
auto slice1 = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto conv = mm1->add_instruction(migraphx::make_op("convolution"), x, w);
auto slice1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
auto a = mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = mm1->add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a);
auto mul = mm1->add_instruction(migraphx::op::mul{}, slice1, b);
auto slice2 = mm1->add_instruction(migraphx::op::slice{{1}, {384}, {768}}, conv);
auto add = mm1->add_instruction(migraphx::op::add{}, mul, slice2);
auto b = mm1->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a);
auto mul = mm1->add_instruction(migraphx::make_op("mul"), slice1, b);
auto slice2 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv);
auto add = mm1->add_instruction(migraphx::make_op("add"), mul, slice2);
mm1->add_instruction(pass_op{}, add);
}
run_pass(p1);
......@@ -245,16 +255,22 @@ TEST_CASE(simplify_mul_slice_conv1)
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = mm2->add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
auto wslice1 = mm2->add_instruction(migraphx::op::slice{{0}, {0}, {384}}, w);
auto wslice1 = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {384}}}), w);
auto a = mm2->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = mm2->add_instruction(migraphx::op::broadcast{0, {384, 1024, 1, 1}}, a);
auto mul = mm2->add_instruction(migraphx::op::mul{}, b, wslice1);
auto wslice2 = mm2->add_instruction(migraphx::op::slice{{0}, {384}, {768}}, w);
auto concat = mm2->add_instruction(migraphx::op::concat{0}, mul, wslice2);
auto conv = mm2->add_instruction(migraphx::op::convolution{}, x, concat);
auto slice1 = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto slice2 = mm2->add_instruction(migraphx::op::slice{{1}, {384}, {768}}, conv);
auto add = mm2->add_instruction(migraphx::op::add{}, slice1, slice2);
auto b = mm2->add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", {384, 1024, 1, 1}}}), a);
auto mul = mm2->add_instruction(migraphx::make_op("mul"), b, wslice1);
auto wslice2 = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w);
auto concat =
mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), mul, wslice2);
auto conv = mm2->add_instruction(migraphx::make_op("convolution"), x, concat);
auto slice1 = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
auto slice2 = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv);
auto add = mm2->add_instruction(migraphx::make_op("add"), slice1, slice2);
mm2->add_instruction(pass_op{}, add);
}
EXPECT(p1 == p2);
......@@ -268,13 +284,16 @@ TEST_CASE(simplify_mul_slice_conv_overlapping_slice)
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = mm1->add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
auto conv = mm1->add_instruction(migraphx::op::convolution{}, x, w);
auto slice1 = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto conv = mm1->add_instruction(migraphx::make_op("convolution"), x, w);
auto slice1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
auto a = mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = mm1->add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a);
auto mul = mm1->add_instruction(migraphx::op::mul{}, slice1, b);
auto slice2 = mm1->add_instruction(migraphx::op::slice{{1}, {383}, {767}}, conv);
auto add = mm1->add_instruction(migraphx::op::add{}, mul, slice2);
auto b = mm1->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a);
auto mul = mm1->add_instruction(migraphx::make_op("mul"), slice1, b);
auto slice2 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {383}}, {"ends", {767}}}), conv);
auto add = mm1->add_instruction(migraphx::make_op("add"), mul, slice2);
mm1->add_instruction(pass_op{}, add);
}
migraphx::program p2 = p1;
......@@ -290,15 +309,17 @@ TEST_CASE(simplify_mul_slice_conv_not_all_slice)
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = mm1->add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
auto conv = mm1->add_instruction(migraphx::op::convolution{}, x, w);
auto slice1 = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto conv = mm1->add_instruction(migraphx::make_op("convolution"), x, w);
auto slice1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
auto a = mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = mm1->add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a);
auto mul = mm1->add_instruction(migraphx::op::mul{}, slice1, b);
auto b = mm1->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a);
auto mul = mm1->add_instruction(migraphx::make_op("mul"), slice1, b);
auto c = mm1->add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {1, 768, 17, 17}}));
auto add = mm1->add_instruction(migraphx::op::add{}, conv, c);
auto concat = mm1->add_instruction(migraphx::op::concat{1}, mul, add);
auto add = mm1->add_instruction(migraphx::make_op("add"), conv, c);
auto concat = mm1->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), mul, add);
mm1->add_instruction(pass_op{}, concat);
}
migraphx::program p2 = p1;
......@@ -314,8 +335,8 @@ TEST_CASE(simplify_mul_add)
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = mm1->add_literal(1);
auto two = mm1->add_literal(2);
auto sum = mm1->add_instruction(migraphx::op::add{}, one, x);
auto mul = mm1->add_instruction(migraphx::op::mul{}, sum, two);
auto sum = mm1->add_instruction(migraphx::make_op("add"), one, x);
auto mul = mm1->add_instruction(migraphx::make_op("mul"), sum, two);
mm1->add_instruction(pass_op{}, mul);
}
run_pass(p1);
......@@ -326,9 +347,9 @@ TEST_CASE(simplify_mul_add)
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto mul1 = mm2->add_instruction(migraphx::op::mul{}, two, x);
auto mul2 = mm2->add_instruction(migraphx::op::mul{}, two, one);
auto sum = mm2->add_instruction(migraphx::op::add{}, mul1, mul2);
auto mul1 = mm2->add_instruction(migraphx::make_op("mul"), two, x);
auto mul2 = mm2->add_instruction(migraphx::make_op("mul"), two, one);
auto sum = mm2->add_instruction(migraphx::make_op("add"), mul1, mul2);
mm2->add_instruction(pass_op{}, sum);
}
EXPECT(p1 == p2);
......@@ -344,7 +365,7 @@ TEST_CASE(simplify_inner_broadcast)
auto y = mm1->add_parameter("y", {migraphx::shape::int32_type, {1}});
auto xb = mm1->add_instruction(b, x);
auto yb = mm1->add_instruction(b, y);
auto sum = mm1->add_instruction(migraphx::op::add{}, xb, yb);
auto sum = mm1->add_instruction(migraphx::make_op("add"), xb, yb);
mm1->add_instruction(pass_op{}, sum);
}
run_pass(p1);
......@@ -354,7 +375,7 @@ TEST_CASE(simplify_inner_broadcast)
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = mm2->add_parameter("y", {migraphx::shape::int32_type, {1}});
auto sum = mm2->add_instruction(migraphx::op::add{}, x, y);
auto sum = mm2->add_instruction(migraphx::make_op("add"), x, y);
auto sumb = mm2->add_instruction(b, sum);
mm2->add_instruction(pass_op{}, sumb);
}
......@@ -371,9 +392,9 @@ TEST_CASE(simplify_add_conv1)
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}});
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 3, 3}}));
auto conv1 = mm->add_instruction(migraphx::op::convolution{}, x, w);
auto conv2 = mm->add_instruction(migraphx::op::convolution{}, y, v);
auto sum = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v);
auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
mm->add_instruction(pass_op{}, sum);
auto s = p.get_output_shapes().back();
run_pass(p);
......@@ -392,9 +413,10 @@ TEST_CASE(simplify_add_conv_no_fusion_7x7_diff_strides)
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}});
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 7, 7}}));
auto conv1 = mm->add_instruction(migraphx::op::convolution{}, x, w);
auto conv2 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {3, 3}}, y, v);
auto sum = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
auto conv2 = mm->add_instruction(
migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {3, 3}}}), y, v);
auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
mm->add_instruction(pass_op{}, sum);
auto s = p.get_output_shapes().back();
run_pass(p);
......@@ -414,9 +436,10 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides1)
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}});
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto conv1 = mm->add_instruction(migraphx::op::convolution{}, x, w);
auto conv2 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, y, v);
auto sum = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
auto conv2 = mm->add_instruction(
migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), y, v);
auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
mm->add_instruction(pass_op{}, sum);
auto s = p.get_output_shapes().back();
run_pass(p);
......@@ -435,9 +458,10 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides2)
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 128, 14, 14}});
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto conv1 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, x, w);
auto conv2 = mm->add_instruction(migraphx::op::convolution{}, y, v);
auto sum = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
auto conv1 = mm->add_instruction(
migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), x, w);
auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v);
auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
mm->add_instruction(pass_op{}, sum);
auto s = p.get_output_shapes().back();
run_pass(p);
......@@ -456,9 +480,10 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides_odd)
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 54, 165, 165}});
auto v =
mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {54, 54, 1, 1}}));
auto conv1 = mm->add_instruction(migraphx::op::convolution{}, x, w);
auto conv2 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, y, v);
auto sum = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
auto conv2 = mm->add_instruction(
migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), y, v);
auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
mm->add_instruction(pass_op{}, sum);
auto s = p.get_output_shapes().back();
run_pass(p);
......@@ -477,9 +502,10 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides1)
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 128, 14, 14}});
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto conv1 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {2, 1}}, x, w);
auto conv2 = mm->add_instruction(migraphx::op::convolution{}, y, v);
auto sum = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
auto conv1 = mm->add_instruction(
migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 1}}}), x, w);
auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v);
auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
mm->add_instruction(pass_op{}, sum);
auto s = p.get_output_shapes().back();
run_pass(p);
......@@ -499,9 +525,10 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2)
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 14}});
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto conv1 = mm->add_instruction(migraphx::op::convolution{}, x, w);
auto conv2 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {2, 1}}, y, v);
auto sum = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
auto conv2 = mm->add_instruction(
migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 1}}}), y, v);
auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
mm->add_instruction(pass_op{}, sum);
auto s = p.get_output_shapes().back();
run_pass(p);
......@@ -516,16 +543,17 @@ TEST_CASE(simplify_concat_add_relu)
auto s = migraphx::shape{migraphx::shape::int32_type, {1}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto one = mm1->add_literal({s, {1}});
auto two = mm1->add_literal({s, {2}});
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, one);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, two);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto concat = mm1->add_instruction(migraphx::op::concat{0}, relu1, relu2);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto one = mm1->add_literal({s, {1}});
auto two = mm1->add_literal({s, {2}});
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), x, one);
auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), y, two);
auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
auto concat =
mm1->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), relu1, relu2);
mm1->add_instruction(pass_op{}, concat);
}
run_pass(p1);
......@@ -537,10 +565,10 @@ TEST_CASE(simplify_concat_add_relu)
auto y = mm2->add_parameter("y", s);
auto one = mm2->add_literal({s, {1}});
auto two = mm2->add_literal({s, {2}});
auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, x, y);
auto concat2 = mm2->add_instruction(migraphx::op::concat{0}, one, two);
auto sum = mm2->add_instruction(migraphx::op::add{}, concat1, concat2);
auto relu = mm2->add_instruction(migraphx::op::relu{}, sum);
auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y);
auto concat2 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
auto sum = mm2->add_instruction(migraphx::make_op("add"), concat1, concat2);
auto relu = mm2->add_instruction(migraphx::make_op("relu"), sum);
mm2->add_instruction(pass_op{}, relu);
}
EXPECT(p1 == p2);
......@@ -551,17 +579,18 @@ TEST_CASE(simplify_concat_add_relu_partial)
auto s = migraphx::shape{migraphx::shape::int32_type, {1}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto one = mm1->add_literal({s, {1}});
auto two = mm1->add_literal({s, {2}});
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, one);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, two);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, x, y);
auto concat = mm1->add_instruction(migraphx::op::concat{0}, sum3, relu1, relu2);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto one = mm1->add_literal({s, {1}});
auto two = mm1->add_literal({s, {2}});
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), x, one);
auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), y, two);
auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
auto sum3 = mm1->add_instruction(migraphx::make_op("add"), x, y);
auto concat =
mm1->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), sum3, relu1, relu2);
mm1->add_instruction(pass_op{}, concat);
}
run_pass(p1);
......@@ -573,12 +602,12 @@ TEST_CASE(simplify_concat_add_relu_partial)
auto y = mm2->add_parameter("y", s);
auto one = mm2->add_literal({s, {1}});
auto two = mm2->add_literal({s, {2}});
auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, x, y);
auto concat2 = mm2->add_instruction(migraphx::op::concat{0}, one, two);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, concat1, concat2);
auto relu = mm2->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, x, y);
auto concat = mm2->add_instruction(migraphx::op::concat{0}, sum2, relu);
auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y);
auto concat2 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
auto sum1 = mm2->add_instruction(migraphx::make_op("add"), concat1, concat2);
auto relu = mm2->add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = mm2->add_instruction(migraphx::make_op("add"), x, y);
auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), sum2, relu);
mm2->add_instruction(pass_op{}, concat);
}
EXPECT(p1.sort() == p2.sort());
......@@ -589,16 +618,17 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum = mm1->add_instruction(migraphx::op::add{}, x, y);
auto concat = mm1->add_instruction(migraphx::op::concat{1}, sum, oneb, twob);
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum = mm1->add_instruction(migraphx::make_op("add"), x, y);
auto concat =
mm1->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), sum, oneb, twob);
mm1->add_instruction(pass_op{}, concat);
}
run_pass(p1);
......@@ -611,10 +641,11 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
auto y = mm2->add_parameter("y", s);
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, one, two);
auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
auto concatb = mm2->add_instruction(b, concat1);
auto sum = mm2->add_instruction(migraphx::op::add{}, x, y);
auto concat2 = mm2->add_instruction(migraphx::op::concat{1}, sum, concatb);
auto sum = mm2->add_instruction(migraphx::make_op("add"), x, y);
auto concat2 =
mm2->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), sum, concatb);
mm2->add_instruction(pass_op{}, concat2);
}
EXPECT(p1.sort() == p2.sort());
......@@ -625,19 +656,20 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto concat = mm1->add_instruction(migraphx::op::concat{1}, relu1, relu2);
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), y, twob);
auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
auto concat =
mm1->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), relu1, relu2);
mm1->add_instruction(pass_op{}, concat);
}
run_pass(p1);
......@@ -650,11 +682,11 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
auto y = mm2->add_parameter("y", s);
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto concat1 = mm2->add_instruction(migraphx::op::concat{1}, x, y);
auto concat2 = mm2->add_instruction(migraphx::op::concat{0}, one, two);
auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y);
auto concat2 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
auto concat2b = mm2->add_instruction(b, concat2);
auto sum = mm2->add_instruction(migraphx::op::add{}, concat1, concat2b);
auto relu = mm2->add_instruction(migraphx::op::relu{}, sum);
auto sum = mm2->add_instruction(migraphx::make_op("add"), concat1, concat2b);
auto relu = mm2->add_instruction(migraphx::make_op("relu"), sum);
mm2->add_instruction(pass_op{}, relu);
}
EXPECT(p1 == p2);
......@@ -665,19 +697,20 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto concat = mm1->add_instruction(migraphx::op::concat{0}, relu1, relu2);
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), y, twob);
auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
auto concat =
mm1->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), relu1, relu2);
mm1->add_instruction(pass_op{}, concat);
}
run_pass(p1);
......@@ -692,10 +725,10 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
auto oneb = mm2->add_instruction(b, one);
auto two = mm2->add_literal(2);
auto twob = mm2->add_instruction(b, two);
auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, x, y);
auto concat2 = mm2->add_instruction(migraphx::op::concat{0}, oneb, twob);
auto sum = mm2->add_instruction(migraphx::op::add{}, concat1, concat2);
auto relu = mm2->add_instruction(migraphx::op::relu{}, sum);
auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y);
auto concat2 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), oneb, twob);
auto sum = mm2->add_instruction(migraphx::make_op("add"), concat1, concat2);
auto relu = mm2->add_instruction(migraphx::make_op("relu"), sum);
mm2->add_instruction(pass_op{}, relu);
}
EXPECT(p1 == p2);
......@@ -708,7 +741,7 @@ TEST_CASE(simplify_div_const)
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto two = mm1->add_literal(2);
mm1->add_instruction(migraphx::op::div{}, x, two);
mm1->add_instruction(migraphx::make_op("div"), x, two);
}
run_pass(p1);
......@@ -717,8 +750,8 @@ TEST_CASE(simplify_div_const)
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto two = mm2->add_literal(2);
auto recip = mm2->insert_instruction(std::next(two), migraphx::op::recip{}, two);
mm2->add_instruction(migraphx::op::mul{}, x, recip);
auto recip = mm2->insert_instruction(std::next(two), migraphx::make_op("recip"), two);
mm2->add_instruction(migraphx::make_op("mul"), x, recip);
}
EXPECT(p1 == p2);
}
......@@ -730,7 +763,7 @@ TEST_CASE(simplify_sub_const)
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto two = mm1->add_literal(2);
mm1->add_instruction(migraphx::op::sub{}, x, two);
mm1->add_instruction(migraphx::make_op("sub"), x, two);
}
run_pass(p1);
......@@ -739,8 +772,8 @@ TEST_CASE(simplify_sub_const)
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto two = mm2->add_literal(2);
auto neg = mm2->insert_instruction(std::next(two), migraphx::op::neg{}, two);
mm2->add_instruction(migraphx::op::add{}, x, neg);
auto neg = mm2->insert_instruction(std::next(two), migraphx::make_op("neg"), two);
mm2->add_instruction(migraphx::make_op("add"), x, neg);
}
EXPECT(p1 == p2);
}
......@@ -751,8 +784,8 @@ TEST_CASE(simplify_rsqrt)
{
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto sqrt = mm1->add_instruction(migraphx::op::sqrt{}, x);
mm1->add_instruction(migraphx::op::recip{}, sqrt);
auto sqrt = mm1->add_instruction(migraphx::make_op("sqrt"), x);
mm1->add_instruction(migraphx::make_op("recip"), sqrt);
}
run_pass(p1);
......@@ -760,7 +793,7 @@ TEST_CASE(simplify_rsqrt)
{
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
mm2->add_instruction(migraphx::op::rsqrt{}, x);
mm2->add_instruction(migraphx::make_op("rsqrt"), x);
}
EXPECT(p1 == p2);
}
......@@ -771,10 +804,10 @@ TEST_CASE(simplify_rsqrt_multi_use)
{
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto sqrt = mm1->add_instruction(migraphx::op::sqrt{}, x);
auto add = mm1->add_instruction(migraphx::op::add{}, sqrt, sqrt);
auto rsqrt = mm1->add_instruction(migraphx::op::recip{}, sqrt);
mm1->add_instruction(migraphx::op::add{}, rsqrt, add);
auto sqrt = mm1->add_instruction(migraphx::make_op("sqrt"), x);
auto add = mm1->add_instruction(migraphx::make_op("add"), sqrt, sqrt);
auto rsqrt = mm1->add_instruction(migraphx::make_op("recip"), sqrt);
mm1->add_instruction(migraphx::make_op("add"), rsqrt, add);
}
migraphx::program p2{p1};
......@@ -791,12 +824,16 @@ TEST_CASE(simplify_slice_concat)
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto xslice1 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {128}}, x);
auto xslice2 = mm1->add_instruction(migraphx::op::slice{{0}, {128}, {256}}, x);
auto yslice1 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {128}}, y);
auto yslice2 = mm1->add_instruction(migraphx::op::slice{{0}, {128}, {256}}, y);
auto concat =
mm1->add_instruction(migraphx::op::concat{0}, xslice1, xslice2, yslice1, yslice2);
auto xslice1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {128}}}), x);
auto xslice2 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {128}}, {"ends", {256}}}), x);
auto yslice1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {128}}}), y);
auto yslice2 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {128}}, {"ends", {256}}}), y);
auto concat = mm1->add_instruction(
migraphx::make_op("concat", {{"axis", 0}}), xslice1, xslice2, yslice1, yslice2);
mm1->add_instruction(pass_op{}, concat);
}
run_pass(p1);
......@@ -806,7 +843,7 @@ TEST_CASE(simplify_slice_concat)
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", s);
auto y = mm2->add_parameter("y", s);
auto concat = mm2->add_instruction(migraphx::op::concat{0}, x, y);
auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y);
mm2->add_instruction(pass_op{}, concat);
}
EXPECT(p1 == p2);
......@@ -821,14 +858,25 @@ TEST_CASE(simplify_slice_concat_non_uniform)
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto xslice1 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {64}}, x);
auto xslice2 = mm1->add_instruction(migraphx::op::slice{{0}, {64}, {192}}, x);
auto xslice3 = mm1->add_instruction(migraphx::op::slice{{0}, {192}, {256}}, x);
auto yslice1 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {64}}, y);
auto yslice2 = mm1->add_instruction(migraphx::op::slice{{0}, {64}, {192}}, y);
auto yslice3 = mm1->add_instruction(migraphx::op::slice{{0}, {192}, {256}}, y);
auto concat = mm1->add_instruction(
migraphx::op::concat{0}, xslice1, xslice2, xslice3, yslice1, yslice2, yslice3);
auto xslice1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {64}}}), x);
auto xslice2 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {192}}}), x);
auto xslice3 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {192}}, {"ends", {256}}}), x);
auto yslice1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {64}}}), y);
auto yslice2 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {192}}}), y);
auto yslice3 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {192}}, {"ends", {256}}}), y);
auto concat = mm1->add_instruction(migraphx::make_op("concat", {{"axis", 0}}),
xslice1,
xslice2,
xslice3,
yslice1,
yslice2,
yslice3);
mm1->add_instruction(pass_op{}, concat);
}
run_pass(p1);
......@@ -838,7 +886,7 @@ TEST_CASE(simplify_slice_concat_non_uniform)
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", s);
auto y = mm2->add_parameter("y", s);
auto concat = mm2->add_instruction(migraphx::op::concat{0}, x, y);
auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y);
mm2->add_instruction(pass_op{}, concat);
}
......@@ -854,14 +902,25 @@ TEST_CASE(simplify_slice_concat_flipped)
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto xslice1 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {64}}, x);
auto xslice2 = mm1->add_instruction(migraphx::op::slice{{0}, {192}, {256}}, x);
auto xslice3 = mm1->add_instruction(migraphx::op::slice{{0}, {64}, {192}}, x);
auto yslice1 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {64}}, y);
auto yslice2 = mm1->add_instruction(migraphx::op::slice{{0}, {192}, {256}}, y);
auto yslice3 = mm1->add_instruction(migraphx::op::slice{{0}, {64}, {192}}, y);
auto concat = mm1->add_instruction(
migraphx::op::concat{0}, xslice1, xslice2, xslice3, yslice1, yslice2, yslice3);
auto xslice1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {64}}}), x);
auto xslice2 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {192}}, {"ends", {256}}}), x);
auto xslice3 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {192}}}), x);
auto yslice1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {64}}}), y);
auto yslice2 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {192}}, {"ends", {256}}}), y);
auto yslice3 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {192}}}), y);
auto concat = mm1->add_instruction(migraphx::make_op("concat", {{"axis", 0}}),
xslice1,
xslice2,
xslice3,
yslice1,
yslice2,
yslice3);
mm1->add_instruction(pass_op{}, concat);
}
migraphx::program p2 = p1;
......@@ -878,17 +937,19 @@ TEST_CASE(simplify_split_add_relu)
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto x = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto add = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), y, twob);
auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
auto add = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
mm1->add_instruction(pass_op{}, add);
}
run_pass(p1);
......@@ -900,13 +961,15 @@ TEST_CASE(simplify_split_add_relu)
auto input = mm2->add_parameter("input", s);
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto concat = mm2->add_instruction(migraphx::op::concat{0}, one, two);
auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
auto concatb = mm2->add_instruction(b, concat);
auto sum = mm2->add_instruction(migraphx::op::add{}, input, concatb);
auto relu = mm2->add_instruction(migraphx::op::relu{}, sum);
auto x = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu);
auto y = mm2->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu);
auto add = mm2->add_instruction(migraphx::op::add{}, x, y);
auto sum = mm2->add_instruction(migraphx::make_op("add"), input, concatb);
auto relu = mm2->add_instruction(migraphx::make_op("relu"), sum);
auto x = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), relu);
auto y = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), relu);
auto add = mm2->add_instruction(migraphx::make_op("add"), x, y);
mm2->add_instruction(pass_op{}, add);
}
EXPECT(p1.sort() == p2.sort());
......@@ -917,23 +980,25 @@ TEST_CASE(simplify_split_add_relu_reshape)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto r = migraphx::op::reshape{{3, 4}};
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto r = migraphx::op::reshape{{3, 4}};
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
auto reshape1 = mm1->add_instruction(r, relu1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), y, twob);
auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
auto reshape2 = mm1->add_instruction(r, relu2);
auto add = mm1->add_instruction(migraphx::op::add{}, reshape1, reshape2);
auto add = mm1->add_instruction(migraphx::make_op("add"), reshape1, reshape2);
mm1->add_instruction(pass_op{}, add);
}
run_pass(p1);
......@@ -945,14 +1010,16 @@ TEST_CASE(simplify_split_add_relu_reshape)
auto input = mm2->add_parameter("input", s);
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto concat = mm2->add_instruction(migraphx::op::concat{0}, one, two);
auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
auto concatb = mm2->add_instruction(b, concat);
auto sum = mm2->add_instruction(migraphx::op::add{}, input, concatb);
auto relu = mm2->add_instruction(migraphx::op::relu{}, sum);
auto rsp = mm2->add_instruction(migraphx::op::reshape{{3, 8}}, relu);
auto slc1 = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {4}}, rsp);
auto slc2 = mm2->add_instruction(migraphx::op::slice{{1}, {4}, {8}}, rsp);
auto add = mm2->add_instruction(migraphx::op::add{}, slc1, slc2);
auto sum = mm2->add_instruction(migraphx::make_op("add"), input, concatb);
auto relu = mm2->add_instruction(migraphx::make_op("relu"), sum);
auto rsp = mm2->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 8}}}), relu);
auto slc1 = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {4}}}), rsp);
auto slc2 = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {4}}, {"ends", {8}}}), rsp);
auto add = mm2->add_instruction(migraphx::make_op("add"), slc1, slc2);
mm2->add_instruction(pass_op{}, add);
}
EXPECT(p1.sort() == p2.sort());
......@@ -963,22 +1030,26 @@ TEST_CASE(simplify_slice_different_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto r = migraphx::op::reshape{{3, 2, 4}};
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{3}, {0}, {1}}, input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(migraphx::op::broadcast{1, {3, 1, 4, 2}}, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(migraphx::op::broadcast{3, {3, 2, 4, 1}}, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto* mm1 = p1.get_main_module();
auto r = migraphx::op::reshape{{3, 2, 4}};
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {1}}}), input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {3, 1, 4, 2}}}), one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(
migraphx::make_op("broadcast", {{"axis", 3}, {"dims", {3, 2, 4, 1}}}), two);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
auto reshape1 = mm1->add_instruction(r, relu1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), y, twob);
auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
auto reshape2 = mm1->add_instruction(r, relu2);
auto add = mm1->add_instruction(migraphx::op::add{}, reshape1, reshape2);
auto add = mm1->add_instruction(migraphx::make_op("add"), reshape1, reshape2);
mm1->add_instruction(pass_op{}, add);
}
migraphx::program p2 = p1;
......@@ -995,17 +1066,19 @@ TEST_CASE(simplify_slice_missing_begining_slice)
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {2}, {3}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto x = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input);
auto y = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto add = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), y, twob);
auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
auto add = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
mm1->add_instruction(pass_op{}, add);
}
migraphx::program p2 = p1;
......@@ -1022,17 +1095,19 @@ TEST_CASE(simplify_slice_missing_middle_slice)
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {2}, {3}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto x = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input);
auto y = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto add = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), y, twob);
auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
auto add = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
mm1->add_instruction(pass_op{}, add);
}
migraphx::program p2 = p1;
......@@ -1049,17 +1124,19 @@ TEST_CASE(simplify_slice_missing_end_slice)
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto x = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto add = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), y, twob);
auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
auto add = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
mm1->add_instruction(pass_op{}, add);
}
migraphx::program p2 = p1;
......@@ -1073,20 +1150,23 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto concat = mm1->add_instruction(migraphx::op::concat{1}, relu1, relu2);
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), y, twob);
auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
auto concat =
mm1->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), relu1, relu2);
mm1->add_instruction(pass_op{}, concat);
}
run_pass(p1);
......@@ -1098,10 +1178,10 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
auto input = mm2->add_parameter("input", s);
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto concat = mm2->add_instruction(migraphx::op::concat{0}, one, two);
auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
auto concatb = mm2->add_instruction(b, concat);
auto sum = mm2->add_instruction(migraphx::op::add{}, input, concatb);
auto relu = mm2->add_instruction(migraphx::op::relu{}, sum);
auto sum = mm2->add_instruction(migraphx::make_op("add"), input, concatb);
auto relu = mm2->add_instruction(migraphx::make_op("relu"), sum);
mm2->add_instruction(pass_op{}, relu);
}
EXPECT(p1.sort() == p2.sort());
......@@ -1115,17 +1195,21 @@ TEST_CASE(simplify_split_add_relu_multi_axes)
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4, 3}};
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1, 3}, {0, 0}, {1, 3}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1, 3}, {1, 3}, {2, 6}}, input);
auto x = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1, 3}}, {"starts", {0, 0}}, {"ends", {1, 3}}}),
input);
auto y = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1, 3}}, {"starts", {1, 3}}, {"ends", {2, 6}}}),
input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto add = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), y, twob);
auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
auto add = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
mm1->add_instruction(pass_op{}, add);
}
migraphx::program p2 = p1;
......@@ -1141,38 +1225,43 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto x = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto add1 = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
auto add2 = mm1->add_instruction(migraphx::op::add{}, x, add1);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), y, twob);
auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
auto add1 = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
auto add2 = mm1->add_instruction(migraphx::make_op("add"), x, add1);
mm1->add_instruction(pass_op{}, add2);
}
run_pass(p1);
migraphx::program p2;
{
auto* mm2 = p2.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto input = mm2->add_parameter("input", s);
auto slice = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto* mm2 = p2.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto input = mm2->add_parameter("input", s);
auto slice = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto concat = mm2->add_instruction(migraphx::op::concat{0}, one, two);
auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
auto concatb = mm2->add_instruction(b, concat);
auto sum = mm2->add_instruction(migraphx::op::add{}, input, concatb);
auto relu = mm2->add_instruction(migraphx::op::relu{}, sum);
auto x = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu);
auto y = mm2->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu);
auto add1 = mm2->add_instruction(migraphx::op::add{}, x, y);
auto add2 = mm2->add_instruction(migraphx::op::add{}, slice, add1);
auto sum = mm2->add_instruction(migraphx::make_op("add"), input, concatb);
auto relu = mm2->add_instruction(migraphx::make_op("relu"), sum);
auto x = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), relu);
auto y = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), relu);
auto add1 = mm2->add_instruction(migraphx::make_op("add"), x, y);
auto add2 = mm2->add_instruction(migraphx::make_op("add"), slice, add1);
mm2->add_instruction(pass_op{}, add2);
}
EXPECT(p1.sort() == p2.sort());
......@@ -1186,40 +1275,45 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto z = mm1->add_instruction(migraphx::op::relu{}, x);
auto x = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto z = mm1->add_instruction(migraphx::make_op("relu"), x);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto add1 = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
auto add2 = mm1->add_instruction(migraphx::op::add{}, z, add1);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), y, twob);
auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
auto add1 = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
auto add2 = mm1->add_instruction(migraphx::make_op("add"), z, add1);
mm1->add_instruction(pass_op{}, add2);
}
run_pass(p1);
migraphx::program p2;
{
auto* mm2 = p2.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto input = mm2->add_parameter("input", s);
auto slice = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto z = mm2->add_instruction(migraphx::op::relu{}, slice);
auto* mm2 = p2.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto input = mm2->add_parameter("input", s);
auto slice = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto z = mm2->add_instruction(migraphx::make_op("relu"), slice);
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto concat = mm2->add_instruction(migraphx::op::concat{0}, one, two);
auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
auto concatb = mm2->add_instruction(b, concat);
auto sum = mm2->add_instruction(migraphx::op::add{}, input, concatb);
auto relu = mm2->add_instruction(migraphx::op::relu{}, sum);
auto x = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu);
auto y = mm2->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu);
auto add1 = mm2->add_instruction(migraphx::op::add{}, x, y);
auto add2 = mm2->add_instruction(migraphx::op::add{}, z, add1);
auto sum = mm2->add_instruction(migraphx::make_op("add"), input, concatb);
auto relu = mm2->add_instruction(migraphx::make_op("relu"), sum);
auto x = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), relu);
auto y = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), relu);
auto add1 = mm2->add_instruction(migraphx::make_op("add"), x, y);
auto add2 = mm2->add_instruction(migraphx::make_op("add"), z, add1);
mm2->add_instruction(pass_op{}, add2);
}
EXPECT(p1.sort() == p2.sort());
......@@ -1232,9 +1326,11 @@ TEST_CASE(simplify_split_between_add)
{
auto* mm1 = p1.get_main_module();
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto sum = mm1->add_instruction(migraphx::op::add{}, x, y);
auto x = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto sum = mm1->add_instruction(migraphx::make_op("add"), x, y);
mm1->add_instruction(pass_op{}, sum);
}
migraphx::program p2 = p1;
......@@ -1251,9 +1347,9 @@ TEST_CASE(simplify_dot_horiz)
auto input = mm1->add_parameter("input", s);
auto a = mm1->add_literal(migraphx::generate_literal(s, 0));
auto b = mm1->add_literal(migraphx::generate_literal(s, 1));
auto x = mm1->add_instruction(migraphx::op::dot{}, input, a);
auto y = mm1->add_instruction(migraphx::op::dot{}, input, b);
auto sum = mm1->add_instruction(migraphx::op::add{}, x, y);
auto x = mm1->add_instruction(migraphx::make_op("dot"), input, a);
auto y = mm1->add_instruction(migraphx::make_op("dot"), input, b);
auto sum = mm1->add_instruction(migraphx::make_op("add"), x, y);
mm1->add_instruction(pass_op{}, sum);
}
run_pass(p1);
......@@ -1264,11 +1360,13 @@ TEST_CASE(simplify_dot_horiz)
auto input = mm2->add_parameter("input", s);
auto a = mm2->add_literal(migraphx::generate_literal(s, 0));
auto b = mm2->add_literal(migraphx::generate_literal(s, 1));
auto concat = mm2->add_instruction(migraphx::op::concat{2}, a, b);
auto dot = mm2->add_instruction(migraphx::op::dot{}, input, concat);
auto x = mm2->add_instruction(migraphx::op::slice{{2}, {0}, {2}}, dot);
auto y = mm2->add_instruction(migraphx::op::slice{{2}, {2}, {4}}, dot);
auto sum = mm2->add_instruction(migraphx::op::add{}, x, y);
auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, b);
auto dot = mm2->add_instruction(migraphx::make_op("dot"), input, concat);
auto x = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot);
auto y = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), dot);
auto sum = mm2->add_instruction(migraphx::make_op("add"), x, y);
mm2->add_instruction(pass_op{}, sum);
}
EXPECT(p1.sort() == p2.sort());
......@@ -1282,9 +1380,9 @@ TEST_CASE(simplify_dot_horiz_same_constant)
auto* mm1 = p1.get_main_module();
auto input = mm1->add_parameter("input", s);
auto a = mm1->add_literal(migraphx::generate_literal(s, 0));
auto x = mm1->add_instruction(migraphx::op::dot{}, input, a);
auto y = mm1->add_instruction(migraphx::op::dot{}, input, a);
auto sum = mm1->add_instruction(migraphx::op::add{}, x, y);
auto x = mm1->add_instruction(migraphx::make_op("dot"), input, a);
auto y = mm1->add_instruction(migraphx::make_op("dot"), input, a);
auto sum = mm1->add_instruction(migraphx::make_op("add"), x, y);
mm1->add_instruction(pass_op{}, sum);
}
run_pass(p1);
......@@ -1294,11 +1392,13 @@ TEST_CASE(simplify_dot_horiz_same_constant)
auto* mm2 = p2.get_main_module();
auto input = mm2->add_parameter("input", s);
auto a = mm2->add_literal(migraphx::generate_literal(s, 0));
auto concat = mm2->add_instruction(migraphx::op::concat{2}, a, a);
auto dot = mm2->add_instruction(migraphx::op::dot{}, input, concat);
auto x = mm2->add_instruction(migraphx::op::slice{{2}, {0}, {2}}, dot);
auto y = mm2->add_instruction(migraphx::op::slice{{2}, {2}, {4}}, dot);
auto sum = mm2->add_instruction(migraphx::op::add{}, x, y);
auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, a);
auto dot = mm2->add_instruction(migraphx::make_op("dot"), input, concat);
auto x = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot);
auto y = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), dot);
auto sum = mm2->add_instruction(migraphx::make_op("add"), x, y);
mm2->add_instruction(pass_op{}, sum);
}
EXPECT(p1.sort() == p2.sort());
......@@ -1313,9 +1413,9 @@ TEST_CASE(simplify_dot_horiz_flipped)
auto input = mm1->add_parameter("input", s);
auto a = mm1->add_literal(migraphx::generate_literal(s, 0));
auto b = mm1->add_literal(migraphx::generate_literal(s, 1));
auto x = mm1->add_instruction(migraphx::op::dot{}, input, a);
auto y = mm1->add_instruction(migraphx::op::dot{}, b, input);
auto sum = mm1->add_instruction(migraphx::op::add{}, x, y);
auto x = mm1->add_instruction(migraphx::make_op("dot"), input, a);
auto y = mm1->add_instruction(migraphx::make_op("dot"), b, input);
auto sum = mm1->add_instruction(migraphx::make_op("add"), x, y);
mm1->add_instruction(pass_op{}, sum);
}
......@@ -1334,9 +1434,9 @@ TEST_CASE(simplify_conv_horiz)
auto input = mm1->add_parameter("input", s);
auto a = mm1->add_literal(migraphx::generate_literal(ws, 0));
auto b = mm1->add_literal(migraphx::generate_literal(ws, 1));
auto x = mm1->add_instruction(migraphx::op::convolution{}, input, a);
auto y = mm1->add_instruction(migraphx::op::convolution{}, input, b);
auto sum = mm1->add_instruction(migraphx::op::add{}, x, y);
auto x = mm1->add_instruction(migraphx::make_op("convolution"), input, a);
auto y = mm1->add_instruction(migraphx::make_op("convolution"), input, b);
auto sum = mm1->add_instruction(migraphx::make_op("add"), x, y);
mm1->add_instruction(pass_op{}, sum);
}
run_pass(p1);
......@@ -1347,11 +1447,13 @@ TEST_CASE(simplify_conv_horiz)
auto input = mm2->add_parameter("input", s);
auto a = mm2->add_literal(migraphx::generate_literal(ws, 0));
auto b = mm2->add_literal(migraphx::generate_literal(ws, 1));
auto concat = mm2->add_instruction(migraphx::op::concat{0}, a, b);
auto conv = mm2->add_instruction(migraphx::op::convolution{}, input, concat);
auto x = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {12}}, conv);
auto y = mm2->add_instruction(migraphx::op::slice{{1}, {12}, {24}}, conv);
auto sum = mm2->add_instruction(migraphx::op::add{}, x, y);
auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), a, b);
auto conv = mm2->add_instruction(migraphx::make_op("convolution"), input, concat);
auto x = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), conv);
auto y = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), conv);
auto sum = mm2->add_instruction(migraphx::make_op("add"), x, y);
mm2->add_instruction(pass_op{}, sum);
}
EXPECT(p1.sort() == p2.sort());
......@@ -1363,14 +1465,22 @@ TEST_CASE(simplify_group_conv_horiz)
auto ws = migraphx::shape{migraphx::shape::int32_type, {32, 1, 7, 7}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", s);
auto w1 = mm1->add_literal(migraphx::generate_literal(ws, 1));
auto w2 = mm1->add_literal(migraphx::generate_literal(ws, 2));
auto conv1 =
mm1->add_instruction(migraphx::op::convolution{{3, 3}, {2, 2}, {1, 1}, 32}, x, w1);
auto conv2 =
mm1->add_instruction(migraphx::op::convolution{{3, 3}, {2, 2}, {1, 1}, 32}, x, w2);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", s);
auto w1 = mm1->add_literal(migraphx::generate_literal(ws, 1));
auto w2 = mm1->add_literal(migraphx::generate_literal(ws, 2));
auto conv1 = mm1->add_instruction(
migraphx::make_op(
"convolution",
{{"padding", {3, 3}}, {"stride", {2, 2}}, {"dilation", {1, 1}}, {"group", 32}}),
x,
w1);
auto conv2 = mm1->add_instruction(
migraphx::make_op(
"convolution",
{{"padding", {3, 3}}, {"stride", {2, 2}}, {"dilation", {1, 1}}, {"group", 32}}),
x,
w2);
mm1->add_instruction(pass_op{}, conv1, conv2);
}
migraphx::program p2 = p1;
......@@ -1392,13 +1502,15 @@ TEST_CASE(simplify_conv_horiz_grouped)
auto b = mm1->add_literal(migraphx::generate_literal(ws1, 1));
auto c = mm1->add_literal(migraphx::generate_literal(ws2, 2));
auto d = mm1->add_literal(migraphx::generate_literal(ws2, 3));
auto convx = mm1->add_instruction(migraphx::op::convolution{{1, 1}}, input, a);
auto convy = mm1->add_instruction(migraphx::op::convolution{{1, 1}}, input, b);
auto dotx = mm1->add_instruction(migraphx::op::dot{}, input, c);
auto doty = mm1->add_instruction(migraphx::op::dot{}, input, d);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, convx, convy);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, dotx, doty);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
auto convx =
mm1->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, a);
auto convy =
mm1->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, b);
auto dotx = mm1->add_instruction(migraphx::make_op("dot"), input, c);
auto doty = mm1->add_instruction(migraphx::make_op("dot"), input, d);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), convx, convy);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), dotx, doty);
auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
mm1->add_instruction(pass_op{}, sum3);
}
......@@ -1412,17 +1524,22 @@ TEST_CASE(simplify_conv_horiz_grouped)
auto b = mm2->add_literal(migraphx::generate_literal(ws1, 1));
auto c = mm2->add_literal(migraphx::generate_literal(ws2, 2));
auto d = mm2->add_literal(migraphx::generate_literal(ws2, 3));
auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, a, b);
auto concat2 = mm2->add_instruction(migraphx::op::concat{3}, c, d);
auto conv = mm2->add_instruction(migraphx::op::convolution{{1, 1}}, input, concat1);
auto convx = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {6}}, conv);
auto convy = mm2->add_instruction(migraphx::op::slice{{1}, {6}, {12}}, conv);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, convx, convy);
auto dot = mm2->add_instruction(migraphx::op::dot{}, input, concat2);
auto dotx = mm2->add_instruction(migraphx::op::slice{{3}, {0}, {64}}, dot);
auto doty = mm2->add_instruction(migraphx::op::slice{{3}, {64}, {128}}, dot);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, dotx, doty);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum1, sum2);
auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), a, b);
auto concat2 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), c, d);
auto conv = mm2->add_instruction(
migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, concat1);
auto convx = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv);
auto convy = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv);
auto sum1 = mm2->add_instruction(migraphx::make_op("add"), convx, convy);
auto dot = mm2->add_instruction(migraphx::make_op("dot"), input, concat2);
auto dotx = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot);
auto doty = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot);
auto sum2 = mm2->add_instruction(migraphx::make_op("add"), dotx, doty);
auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum1, sum2);
mm2->add_instruction(pass_op{}, sum3);
}
EXPECT(p1.sort() == p2.sort());
......@@ -1435,23 +1552,25 @@ TEST_CASE(simplify_conv_horiz_grouped_extra1)
auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto input = mm1->add_parameter("input", s);
auto a = mm1->add_literal(migraphx::generate_literal(ws1, 0));
auto b = mm1->add_literal(migraphx::generate_literal(ws1, 1));
auto c = mm1->add_literal(migraphx::generate_literal(ws2, 2));
auto d = mm1->add_literal(migraphx::generate_literal(ws2, 3));
auto e = mm1->add_literal(migraphx::generate_literal(s, 4));
auto convx = mm1->add_instruction(migraphx::op::convolution{{1, 1}}, input, a);
auto convy = mm1->add_instruction(migraphx::op::convolution{{1, 1}}, input, b);
auto dotx = mm1->add_instruction(migraphx::op::dot{}, input, c);
auto doty = mm1->add_instruction(migraphx::op::dot{}, input, d);
auto sqdiffx = mm1->add_instruction(migraphx::op::sqdiff{}, input, e);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, convx, convy);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, dotx, doty);
auto* mm1 = p1.get_main_module();
auto input = mm1->add_parameter("input", s);
auto a = mm1->add_literal(migraphx::generate_literal(ws1, 0));
auto b = mm1->add_literal(migraphx::generate_literal(ws1, 1));
auto c = mm1->add_literal(migraphx::generate_literal(ws2, 2));
auto d = mm1->add_literal(migraphx::generate_literal(ws2, 3));
auto e = mm1->add_literal(migraphx::generate_literal(s, 4));
auto convx =
mm1->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, a);
auto convy =
mm1->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, b);
auto dotx = mm1->add_instruction(migraphx::make_op("dot"), input, c);
auto doty = mm1->add_instruction(migraphx::make_op("dot"), input, d);
auto sqdiffx = mm1->add_instruction(migraphx::make_op("sqdiff"), input, e);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), convx, convy);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), dotx, doty);
auto sum3 = sqdiffx;
auto sum4 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = mm1->add_instruction(migraphx::op::add{}, sum4, sum3);
auto sum4 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
auto sum5 = mm1->add_instruction(migraphx::make_op("add"), sum4, sum3);
mm1->add_instruction(pass_op{}, sum5);
}
run_pass(p1);
......@@ -1465,20 +1584,25 @@ TEST_CASE(simplify_conv_horiz_grouped_extra1)
auto c = mm2->add_literal(migraphx::generate_literal(ws2, 2));
auto d = mm2->add_literal(migraphx::generate_literal(ws2, 3));
auto e = mm2->add_literal(migraphx::generate_literal(s, 4));
auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, a, b);
auto concat2 = mm2->add_instruction(migraphx::op::concat{3}, c, d);
auto conv = mm2->add_instruction(migraphx::op::convolution{{1, 1}}, input, concat1);
auto convx = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {6}}, conv);
auto convy = mm2->add_instruction(migraphx::op::slice{{1}, {6}, {12}}, conv);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, convx, convy);
auto dot = mm2->add_instruction(migraphx::op::dot{}, input, concat2);
auto dotx = mm2->add_instruction(migraphx::op::slice{{3}, {0}, {64}}, dot);
auto doty = mm2->add_instruction(migraphx::op::slice{{3}, {64}, {128}}, dot);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, dotx, doty);
auto sqdiffx = mm2->add_instruction(migraphx::op::sqdiff{}, input, e);
auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), a, b);
auto concat2 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), c, d);
auto conv = mm2->add_instruction(
migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, concat1);
auto convx = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv);
auto convy = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv);
auto sum1 = mm2->add_instruction(migraphx::make_op("add"), convx, convy);
auto dot = mm2->add_instruction(migraphx::make_op("dot"), input, concat2);
auto dotx = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot);
auto doty = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot);
auto sum2 = mm2->add_instruction(migraphx::make_op("add"), dotx, doty);
auto sqdiffx = mm2->add_instruction(migraphx::make_op("sqdiff"), input, e);
auto sum3 = sqdiffx;
auto sum4 = mm2->add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = mm2->add_instruction(migraphx::op::add{}, sum4, sum3);
auto sum4 = mm2->add_instruction(migraphx::make_op("add"), sum1, sum2);
auto sum5 = mm2->add_instruction(migraphx::make_op("add"), sum4, sum3);
mm2->add_instruction(pass_op{}, sum5);
}
EXPECT(p1.sort() == p2.sort());
......@@ -1491,25 +1615,27 @@ TEST_CASE(simplify_conv_horiz_grouped_extra2)
auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto input = mm1->add_parameter("input", s);
auto a = mm1->add_literal(migraphx::generate_literal(ws1, 0));
auto b = mm1->add_literal(migraphx::generate_literal(ws1, 1));
auto c = mm1->add_literal(migraphx::generate_literal(ws2, 2));
auto d = mm1->add_literal(migraphx::generate_literal(ws2, 3));
auto e = mm1->add_literal(migraphx::generate_literal(s, 4));
auto f = mm1->add_literal(migraphx::generate_literal(s, 5));
auto convx = mm1->add_instruction(migraphx::op::convolution{{1, 1}}, input, a);
auto convy = mm1->add_instruction(migraphx::op::convolution{{1, 1}}, input, b);
auto dotx = mm1->add_instruction(migraphx::op::dot{}, input, c);
auto doty = mm1->add_instruction(migraphx::op::dot{}, input, d);
auto sqdiffx = mm1->add_instruction(migraphx::op::sqdiff{}, input, e);
auto sqdiffy = mm1->add_instruction(migraphx::op::sqdiff{}, input, f);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, convx, convy);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, dotx, doty);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sqdiffx, sqdiffy);
auto sum4 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = mm1->add_instruction(migraphx::op::add{}, sum4, sum3);
auto* mm1 = p1.get_main_module();
auto input = mm1->add_parameter("input", s);
auto a = mm1->add_literal(migraphx::generate_literal(ws1, 0));
auto b = mm1->add_literal(migraphx::generate_literal(ws1, 1));
auto c = mm1->add_literal(migraphx::generate_literal(ws2, 2));
auto d = mm1->add_literal(migraphx::generate_literal(ws2, 3));
auto e = mm1->add_literal(migraphx::generate_literal(s, 4));
auto f = mm1->add_literal(migraphx::generate_literal(s, 5));
auto convx =
mm1->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, a);
auto convy =
mm1->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, b);
auto dotx = mm1->add_instruction(migraphx::make_op("dot"), input, c);
auto doty = mm1->add_instruction(migraphx::make_op("dot"), input, d);
auto sqdiffx = mm1->add_instruction(migraphx::make_op("sqdiff"), input, e);
auto sqdiffy = mm1->add_instruction(migraphx::make_op("sqdiff"), input, f);
auto sum1 = mm1->add_instruction(migraphx::make_op("add"), convx, convy);
auto sum2 = mm1->add_instruction(migraphx::make_op("add"), dotx, doty);
auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sqdiffx, sqdiffy);
auto sum4 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
auto sum5 = mm1->add_instruction(migraphx::make_op("add"), sum4, sum3);
mm1->add_instruction(pass_op{}, sum5);
}
run_pass(p1);
......@@ -1524,21 +1650,26 @@ TEST_CASE(simplify_conv_horiz_grouped_extra2)
auto d = mm2->add_literal(migraphx::generate_literal(ws2, 3));
auto e = mm2->add_literal(migraphx::generate_literal(s, 4));
auto f = mm2->add_literal(migraphx::generate_literal(s, 5));
auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, a, b);
auto concat2 = mm2->add_instruction(migraphx::op::concat{3}, c, d);
auto conv = mm2->add_instruction(migraphx::op::convolution{{1, 1}}, input, concat1);
auto convx = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {6}}, conv);
auto convy = mm2->add_instruction(migraphx::op::slice{{1}, {6}, {12}}, conv);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, convx, convy);
auto dot = mm2->add_instruction(migraphx::op::dot{}, input, concat2);
auto dotx = mm2->add_instruction(migraphx::op::slice{{3}, {0}, {64}}, dot);
auto doty = mm2->add_instruction(migraphx::op::slice{{3}, {64}, {128}}, dot);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, dotx, doty);
auto sqdiffx = mm2->add_instruction(migraphx::op::sqdiff{}, input, e);
auto sqdiffy = mm2->add_instruction(migraphx::op::sqdiff{}, input, f);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sqdiffx, sqdiffy);
auto sum4 = mm2->add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = mm2->add_instruction(migraphx::op::add{}, sum4, sum3);
auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), a, b);
auto concat2 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), c, d);
auto conv = mm2->add_instruction(
migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, concat1);
auto convx = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv);
auto convy = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv);
auto sum1 = mm2->add_instruction(migraphx::make_op("add"), convx, convy);
auto dot = mm2->add_instruction(migraphx::make_op("dot"), input, concat2);
auto dotx = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot);
auto doty = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot);
auto sum2 = mm2->add_instruction(migraphx::make_op("add"), dotx, doty);
auto sqdiffx = mm2->add_instruction(migraphx::make_op("sqdiff"), input, e);
auto sqdiffy = mm2->add_instruction(migraphx::make_op("sqdiff"), input, f);
auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sqdiffx, sqdiffy);
auto sum4 = mm2->add_instruction(migraphx::make_op("add"), sum1, sum2);
auto sum5 = mm2->add_instruction(migraphx::make_op("add"), sum4, sum3);
mm2->add_instruction(pass_op{}, sum5);
}
EXPECT(p1.sort() == p2.sort());
......@@ -1552,21 +1683,26 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = mm1->add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
auto conv = mm1->add_instruction(migraphx::op::convolution{}, x, w);
auto slice1 = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto conv = mm1->add_instruction(migraphx::make_op("convolution"), x, w);
auto slice1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
auto a1 =
mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
auto b1 = mm1->add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a1);
auto mul = mm1->add_instruction(migraphx::op::mul{}, slice1, b1);
auto b1 = mm1->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a1);
auto mul = mm1->add_instruction(migraphx::make_op("mul"), slice1, b1);
auto a2 =
mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2));
auto b2 = mm1->add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a2);
auto add1 = mm1->add_instruction(migraphx::op::add{}, mul, b2);
auto b2 = mm1->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a2);
auto add1 = mm1->add_instruction(migraphx::make_op("add"), mul, b2);
auto a3 =
mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
auto b3 = mm1->add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a3);
auto slice2 = mm1->add_instruction(migraphx::op::slice{{1}, {384}, {768}}, conv);
auto add2 = mm1->add_instruction(migraphx::op::add{}, slice2, b3);
auto b3 = mm1->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a3);
auto slice2 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv);
auto add2 = mm1->add_instruction(migraphx::make_op("add"), slice2, b3);
mm1->add_instruction(pass_op{}, add1, add2);
}
run_pass(p1);
......@@ -1577,23 +1713,30 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = mm2->add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
auto wslice1 = mm2->add_instruction(migraphx::op::slice{{0}, {0}, {384}}, w);
auto wslice1 = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {384}}}), w);
auto a1 =
mm2->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
auto b1 = mm2->add_instruction(migraphx::op::broadcast{0, {384, 1024, 1, 1}}, a1);
auto mul = mm2->add_instruction(migraphx::op::mul{}, b1, wslice1);
auto wslice2 = mm2->add_instruction(migraphx::op::slice{{0}, {384}, {768}}, w);
auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, mul, wslice2);
auto conv = mm2->add_instruction(migraphx::op::convolution{}, x, concat1);
auto b1 = mm2->add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", {384, 1024, 1, 1}}}), a1);
auto mul = mm2->add_instruction(migraphx::make_op("mul"), b1, wslice1);
auto wslice2 = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w);
auto concat1 =
mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), mul, wslice2);
auto conv = mm2->add_instruction(migraphx::make_op("convolution"), x, concat1);
auto a2 =
mm2->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2));
auto a3 =
mm2->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
auto concat2 = mm2->add_instruction(migraphx::op::concat{}, a2, a3);
auto b4 = mm2->add_instruction(migraphx::op::broadcast{1, {1, 768, 17, 17}}, concat2);
auto add = mm2->add_instruction(migraphx::op::add{}, conv, b4);
auto slice1 = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {384}}, add);
auto slice2 = mm2->add_instruction(migraphx::op::slice{{1}, {384}, {768}}, add);
auto concat2 = mm2->add_instruction(migraphx::make_op("concat"), a2, a3);
auto b4 = mm2->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 768, 17, 17}}}), concat2);
auto add = mm2->add_instruction(migraphx::make_op("add"), conv, b4);
auto slice1 = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), add);
auto slice2 = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), add);
mm2->add_instruction(pass_op{}, slice1, slice2);
}
EXPECT(p1.sort() == p2.sort());
......@@ -1607,25 +1750,30 @@ TEST_CASE(reorder_reshape_slice)
auto* mm1 = p1.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
auto input = mm1->add_parameter("input", s);
auto slc0 = mm1->add_instruction(migraphx::op::slice{{2}, {0}, {640}}, input);
auto slc1 = mm1->add_instruction(migraphx::op::slice{{2}, {640}, {1280}}, input);
auto slc2 = mm1->add_instruction(migraphx::op::slice{{2}, {1280}, {1920}}, input);
auto c0 = mm1->add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = mm1->add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = mm1->add_instruction(migraphx::op::contiguous{}, slc2);
auto slc0 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);
auto slc1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {640}}, {"ends", {1280}}}),
input);
auto slc2 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
input);
auto c0 = mm1->add_instruction(migraphx::make_op("contiguous"), slc0);
auto c1 = mm1->add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = mm1->add_instruction(migraphx::make_op("contiguous"), slc2);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 128, 10, 64};
auto r0 = mm1->add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = mm1->add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = mm1->add_instruction(migraphx::op::reshape{lens}, c2);
auto r0 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
auto t0 = mm1->add_instruction(migraphx::op::transpose{perm0}, r0);
auto t1 = mm1->add_instruction(migraphx::op::transpose{perm0}, r1);
auto t2 = mm1->add_instruction(migraphx::op::transpose{perm1}, r2);
auto t0 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r0);
auto t1 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r1);
auto t2 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), r2);
auto sum = mm1->add_instruction(migraphx::op::add{}, t0, t1);
auto ret = mm1->add_instruction(migraphx::op::dot{}, sum, t2);
auto sum = mm1->add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = mm1->add_instruction(migraphx::make_op("dot"), sum, t2);
mm1->add_return({ret});
return p1;
......@@ -1637,18 +1785,21 @@ TEST_CASE(reorder_reshape_slice)
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
auto input = mm2->add_parameter("input", s);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 128, 30, 64};
auto r = mm2->add_instruction(migraphx::op::reshape{lens}, input);
auto r = mm2->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = mm2->add_instruction(migraphx::op::slice{{2}, {0}, {10}}, r);
auto slc1 = mm2->add_instruction(migraphx::op::slice{{2}, {10}, {20}}, r);
auto slc2 = mm2->add_instruction(migraphx::op::slice{{2}, {20}, {30}}, r);
auto slc0 = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {10}}}), r);
auto slc1 = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {10}}, {"ends", {20}}}), r);
auto slc2 = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {20}}, {"ends", {30}}}), r);
auto t0 = mm2->add_instruction(migraphx::op::transpose{perm0}, slc0);
auto t1 = mm2->add_instruction(migraphx::op::transpose{perm0}, slc1);
auto t2 = mm2->add_instruction(migraphx::op::transpose{perm1}, slc2);
auto t0 = mm2->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc0);
auto t1 = mm2->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc1);
auto t2 = mm2->add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), slc2);
auto sum = mm2->add_instruction(migraphx::op::add{}, t0, t1);
auto ret = mm2->add_instruction(migraphx::op::dot{}, sum, t2);
auto sum = mm2->add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = mm2->add_instruction(migraphx::make_op("dot"), sum, t2);
mm2->add_return({ret});
return p2;
......@@ -1675,25 +1826,28 @@ TEST_CASE(reorder_reshape_slice_move_axis1)
std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto input = mm1->add_parameter("input", s);
auto slc0 = mm1->add_instruction(migraphx::op::slice{{2}, {0}, {32}}, input);
auto slc1 = mm1->add_instruction(migraphx::op::slice{{2}, {32}, {64}}, input);
auto slc2 = mm1->add_instruction(migraphx::op::slice{{2}, {64}, {96}}, input);
auto slc0 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), input);
auto slc1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), input);
auto slc2 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {64}}, {"ends", {96}}}), input);
auto c0 = mm1->add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = mm1->add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = mm1->add_instruction(migraphx::op::contiguous{}, slc2);
auto c0 = mm1->add_instruction(migraphx::make_op("contiguous"), slc0);
auto c1 = mm1->add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = mm1->add_instruction(migraphx::make_op("contiguous"), slc2);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 64, 4, 32};
auto r0 = mm1->add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = mm1->add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = mm1->add_instruction(migraphx::op::reshape{lens}, c2);
auto r0 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
auto t0 = mm1->add_instruction(migraphx::op::transpose{perm0}, r0);
auto t1 = mm1->add_instruction(migraphx::op::transpose{perm0}, r1);
auto t2 = mm1->add_instruction(migraphx::op::transpose{perm1}, r2);
auto t0 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r0);
auto t1 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r1);
auto t2 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), r2);
auto sum = mm1->add_instruction(migraphx::op::add{}, t0, t1);
auto ret = mm1->add_instruction(migraphx::op::dot{}, sum, t2);
auto sum = mm1->add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = mm1->add_instruction(migraphx::make_op("dot"), sum, t2);
mm1->add_return({ret});
return p1;
......@@ -1707,16 +1861,19 @@ TEST_CASE(reorder_reshape_slice_move_axis1)
std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto input = mm->add_parameter("input", s);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 64, 4, 96};
auto rsp = mm->add_instruction(migraphx::op::reshape{lens}, input);
auto slc0 = mm->add_instruction(migraphx::op::slice{{3}, {0}, {32}}, rsp);
auto t0 = mm->add_instruction(migraphx::op::transpose{perm0}, slc0);
auto slc1 = mm->add_instruction(migraphx::op::slice{{3}, {32}, {64}}, rsp);
auto t1 = mm->add_instruction(migraphx::op::transpose{perm0}, slc1);
auto slc2 = mm->add_instruction(migraphx::op::slice{{3}, {64}, {96}}, rsp);
auto t2 = mm->add_instruction(migraphx::op::transpose{perm1}, slc2);
auto sum = mm->add_instruction(migraphx::op::add{}, t0, t1);
auto ret = mm->add_instruction(migraphx::op::dot{}, sum, t2);
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp);
auto t0 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc0);
auto slc1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc1);
auto slc2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp);
auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), slc2);
auto sum = mm->add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = mm->add_instruction(migraphx::make_op("dot"), sum, t2);
mm->add_return({ret});
return p;
......@@ -1740,21 +1897,24 @@ TEST_CASE(reorder_reshape_slice_move_axis2)
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {128, 96}};
auto input = mm1->add_parameter("input", s);
auto slc0 = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {32}}, input);
auto slc1 = mm1->add_instruction(migraphx::op::slice{{1}, {32}, {64}}, input);
auto slc2 = mm1->add_instruction(migraphx::op::slice{{1}, {64}, {96}}, input);
auto slc0 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {32}}}), input);
auto slc1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {32}}, {"ends", {64}}}), input);
auto slc2 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {64}}, {"ends", {96}}}), input);
auto c0 = mm1->add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = mm1->add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = mm1->add_instruction(migraphx::op::contiguous{}, slc2);
auto c0 = mm1->add_instruction(migraphx::make_op("contiguous"), slc0);
auto c1 = mm1->add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = mm1->add_instruction(migraphx::make_op("contiguous"), slc2);
std::vector<int64_t> lens = {1, 16, 8, 32};
auto r0 = mm1->add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = mm1->add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = mm1->add_instruction(migraphx::op::reshape{lens}, c2);
auto r0 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
auto sum = mm1->add_instruction(migraphx::op::add{}, r0, r1);
auto ret = mm1->add_instruction(migraphx::op::mul{}, sum, r2);
auto sum = mm1->add_instruction(migraphx::make_op("add"), r0, r1);
auto ret = mm1->add_instruction(migraphx::make_op("mul"), sum, r2);
mm1->add_return({ret});
return p1;
......@@ -1766,13 +1926,16 @@ TEST_CASE(reorder_reshape_slice_move_axis2)
auto s = migraphx::shape{migraphx::shape::float_type, {128, 96}};
auto input = mm->add_parameter("input", s);
std::vector<int64_t> lens = {1, 16, 8, 96};
auto rsp = mm->add_instruction(migraphx::op::reshape{lens}, input);
auto slc0 = mm->add_instruction(migraphx::op::slice{{3}, {0}, {32}}, rsp);
auto slc1 = mm->add_instruction(migraphx::op::slice{{3}, {32}, {64}}, rsp);
auto slc2 = mm->add_instruction(migraphx::op::slice{{3}, {64}, {96}}, rsp);
auto sum = mm->add_instruction(migraphx::op::add{}, slc0, slc1);
auto ret = mm->add_instruction(migraphx::op::mul{}, sum, slc2);
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp);
auto slc1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp);
auto slc2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp);
auto sum = mm->add_instruction(migraphx::make_op("add"), slc0, slc1);
auto ret = mm->add_instruction(migraphx::make_op("mul"), sum, slc2);
mm->add_return({ret});
return p;
......@@ -1791,21 +1954,24 @@ TEST_CASE(reorder_reshape_slice_not_apply)
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {128, 96}};
auto input = mm->add_parameter("input", s);
auto slc0 = mm->add_instruction(migraphx::op::slice{{1}, {0}, {32}}, input);
auto slc1 = mm->add_instruction(migraphx::op::slice{{1}, {32}, {64}}, input);
auto slc2 = mm->add_instruction(migraphx::op::slice{{1}, {64}, {96}}, input);
auto slc0 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {32}}}), input);
auto slc1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {32}}, {"ends", {64}}}), input);
auto slc2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {64}}, {"ends", {96}}}), input);
auto c0 = mm->add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = mm->add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = mm->add_instruction(migraphx::op::contiguous{}, slc2);
auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), slc0);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = mm->add_instruction(migraphx::make_op("contiguous"), slc2);
std::vector<int64_t> lens = {1, 16, 16, 16};
auto r0 = mm->add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = mm->add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = mm->add_instruction(migraphx::op::reshape{lens}, c2);
auto r0 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
auto sum = mm->add_instruction(migraphx::op::add{}, r0, r1);
auto ret = mm->add_instruction(migraphx::op::mul{}, sum, r2);
auto sum = mm->add_instruction(migraphx::make_op("add"), r0, r1);
auto ret = mm->add_instruction(migraphx::make_op("mul"), sum, r2);
mm->add_return({ret});
return p;
......@@ -1826,19 +1992,22 @@ TEST_CASE(reorder_reshape_slice_diff_dims)
std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto input = mm1->add_parameter("input", s);
auto slc0 = mm1->add_instruction(migraphx::op::slice{{2}, {0}, {32}}, input);
auto slc1 = mm1->add_instruction(migraphx::op::slice{{2}, {32}, {64}}, input);
auto slc2 = mm1->add_instruction(migraphx::op::slice{{2}, {64}, {96}}, input);
auto slc0 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), input);
auto slc1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), input);
auto slc2 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {64}}, {"ends", {96}}}), input);
auto c0 = mm1->add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = mm1->add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = mm1->add_instruction(migraphx::op::contiguous{}, slc2);
auto c0 = mm1->add_instruction(migraphx::make_op("contiguous"), slc0);
auto c1 = mm1->add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = mm1->add_instruction(migraphx::make_op("contiguous"), slc2);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 32, 3, 32};
std::vector<int64_t> lens1 = {static_cast<int64_t>(batch_size), 48, 2, 32};
auto r0 = mm1->add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = mm1->add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = mm1->add_instruction(migraphx::op::reshape{lens1}, c2);
auto r0 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens1}}), c2);
mm1->add_return({r0, r1, r2});
......@@ -1864,16 +2033,21 @@ TEST_CASE(reorder_slice_trans)
auto* mm1 = p1.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
auto input = mm1->add_parameter("input", s);
auto slc0 = mm1->add_instruction(migraphx::op::slice{{2}, {0}, {640}}, input);
auto slc1 = mm1->add_instruction(migraphx::op::slice{{2}, {640}, {1280}}, input);
auto slc2 = mm1->add_instruction(migraphx::op::slice{{2}, {1280}, {1920}}, input);
auto t0 = mm1->add_instruction(migraphx::op::transpose{perm}, slc0);
auto t1 = mm1->add_instruction(migraphx::op::transpose{perm}, slc1);
auto t2 = mm1->add_instruction(migraphx::op::transpose{perm}, slc2);
auto sum = mm1->add_instruction(migraphx::op::add{}, t0, t1);
auto ret = mm1->add_instruction(migraphx::op::mul{}, sum, t2);
auto slc0 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);
auto slc1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {640}}, {"ends", {1280}}}),
input);
auto slc2 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
input);
auto t0 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), slc0);
auto t1 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), slc1);
auto t2 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), slc2);
auto sum = mm1->add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = mm1->add_instruction(migraphx::make_op("mul"), sum, t2);
mm1->add_return({ret});
return p1;
......@@ -1884,14 +2058,17 @@ TEST_CASE(reorder_slice_trans)
auto* mm2 = p2.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
auto input = mm2->add_parameter("input", s);
auto r = mm2->add_instruction(migraphx::op::transpose{perm}, input);
auto r = mm2->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), input);
auto slc0 = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {640}}, r);
auto slc1 = mm2->add_instruction(migraphx::op::slice{{1}, {640}, {1280}}, r);
auto slc2 = mm2->add_instruction(migraphx::op::slice{{1}, {1280}, {1920}}, r);
auto slc0 = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {640}}}), r);
auto slc1 = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {640}}, {"ends", {1280}}}), r);
auto slc2 = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1280}}, {"ends", {1920}}}), r);
auto sum = mm2->add_instruction(migraphx::op::add{}, slc0, slc1);
auto ret = mm2->add_instruction(migraphx::op::mul{}, sum, slc2);
auto sum = mm2->add_instruction(migraphx::make_op("add"), slc0, slc1);
auto ret = mm2->add_instruction(migraphx::make_op("mul"), sum, slc2);
mm2->add_return({ret});
return p2;
......@@ -1917,16 +2094,21 @@ TEST_CASE(reorder_slice_trans_diff_perm)
std::vector<int64_t> perm0 = {0, 2, 1};
std::vector<int64_t> perm1 = {0, 1, 2};
auto input = mm1->add_parameter("input", s);
auto slc0 = mm1->add_instruction(migraphx::op::slice{{2}, {0}, {640}}, input);
auto slc1 = mm1->add_instruction(migraphx::op::slice{{2}, {640}, {1280}}, input);
auto slc2 = mm1->add_instruction(migraphx::op::slice{{2}, {1280}, {1920}}, input);
auto t0 = mm1->add_instruction(migraphx::op::transpose{perm0}, slc0);
auto t1 = mm1->add_instruction(migraphx::op::transpose{perm0}, slc1);
auto t2 = mm1->add_instruction(migraphx::op::transpose{perm1}, slc2);
auto sum = mm1->add_instruction(migraphx::op::add{}, t0, t1);
auto ret = mm1->add_instruction(migraphx::op::dot{}, sum, t2);
auto slc0 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);
auto slc1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {640}}, {"ends", {1280}}}),
input);
auto slc2 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
input);
auto t0 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc0);
auto t1 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc1);
auto t2 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), slc2);
auto sum = mm1->add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = mm1->add_instruction(migraphx::make_op("dot"), sum, t2);
mm1->add_return({ret});
return p1;
......
......@@ -5,6 +5,10 @@
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <test.hpp>
void run_pass(migraphx::program& p)
......@@ -19,9 +23,9 @@ TEST_CASE(double_contig)
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c1 = mm->add_instruction(migraphx::op::contiguous{}, t1);
auto c2 = mm->add_instruction(migraphx::op::contiguous{}, c1);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), t1);
auto c2 = mm->add_instruction(migraphx::make_op("contiguous"), c1);
mm->add_return({c2});
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
......@@ -39,8 +43,8 @@ TEST_CASE(double_transpose)
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto t2 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, t1);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t1);
mm->add_return({t2});
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
......@@ -58,10 +62,10 @@ TEST_CASE(double_transpose_contig)
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c1 = mm->add_instruction(migraphx::op::contiguous{}, t1);
auto t2 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, c1);
auto c2 = mm->add_instruction(migraphx::op::contiguous{}, t2);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), t1);
auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), c1);
auto c2 = mm->add_instruction(migraphx::make_op("contiguous"), t2);
mm->add_return({c2});
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
......@@ -79,7 +83,7 @@ TEST_CASE(single_transpose)
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
mm->add_return({t1});
EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed());
......@@ -97,8 +101,8 @@ TEST_CASE(double_transpose_sin_pass)
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
mm->add_instruction(migraphx::op::transpose{{1, 0}}, t1);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t1);
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p);
......@@ -116,7 +120,7 @@ TEST_CASE(single_transpose_sin_pass)
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed());
run_pass(p);
......@@ -134,10 +138,10 @@ TEST_CASE(reshape_transpose)
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};
auto x = mm->add_parameter("x", s);
auto r1 = mm->add_instruction(migraphx::op::reshape{{1, 4, 28, 56, 56}}, x);
auto t = mm->add_instruction(migraphx::op::transpose{{0, 2, 1, 3, 4}}, r1);
auto ct = mm->add_instruction(migraphx::op::contiguous{}, t);
auto r2 = mm->add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, ct);
auto r1 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 4, 28, 56, 56}}}), x);
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 1, 3, 4}}}), r1);
auto ct = mm->add_instruction(migraphx::make_op("contiguous"), t);
auto r2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 112, 56, 56}}}), ct);
mm->add_return({r2});
EXPECT(p.get_output_shapes().back() == s);
auto n = std::distance(p.begin(), p.end());
......@@ -153,8 +157,8 @@ TEST_CASE(transpose_contiguous)
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}};
auto x = mm->add_parameter("x", s);
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto c1 = mm->add_instruction(migraphx::op::contiguous{}, t);
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), x);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), t);
mm->add_return({c1});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
......@@ -170,9 +174,9 @@ TEST_CASE(transpose_double_contiguous)
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}};
auto x = mm->add_parameter("x", s);
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto c1 = mm->add_instruction(migraphx::op::contiguous{}, t);
auto c2 = mm->add_instruction(migraphx::op::contiguous{}, c1);
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), x);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), t);
auto c2 = mm->add_instruction(migraphx::make_op("contiguous"), c1);
mm->add_return({c2});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
......@@ -189,8 +193,8 @@ TEST_CASE(transpose_partial1)
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = mm->add_parameter("x", s);
auto t1 = mm->add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = mm->add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), x);
auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), t1);
mm->add_return({t2});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
......@@ -206,9 +210,9 @@ TEST_CASE(transpose_partial2)
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = mm->add_parameter("x", s);
auto t1 = mm->add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = mm->add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = mm->add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), x);
auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), t1);
auto t3 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), t2);
mm->add_return({t3});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
......@@ -224,10 +228,10 @@ TEST_CASE(transpose_partial3)
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = mm->add_parameter("x", s);
auto t1 = mm->add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = mm->add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = mm->add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
auto t4 = mm->add_instruction(migraphx::op::transpose{{1, 0, 2}}, t3);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), x);
auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), t1);
auto t3 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), t2);
auto t4 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), t3);
mm->add_return({t4});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
......@@ -243,7 +247,7 @@ TEST_CASE(nop_transpose1)
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = mm->add_parameter("x", s);
auto t = mm->add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), x);
mm->add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
......@@ -259,10 +263,10 @@ TEST_CASE(nop_transpose2)
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = mm->add_parameter("x", s);
auto t1 = mm->add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
auto t2 = mm->add_instruction(migraphx::op::transpose{{0, 1, 2}}, t1);
auto t3 = mm->add_instruction(migraphx::op::transpose{{0, 1, 2}}, t2);
auto t4 = mm->add_instruction(migraphx::op::transpose{{0, 1, 2}}, t3);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), x);
auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), t1);
auto t3 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), t2);
auto t4 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), t3);
mm->add_instruction(pass_op{}, t4);
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
......@@ -279,9 +283,9 @@ TEST_CASE(nop_transpose3)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto concat = mm->add_instruction(migraphx::op::concat{3}, x, y);
auto t1 = mm->add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, concat);
auto t2 = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, t1);
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), x, y);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2, 3}}}), concat);
auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t1);
mm->add_return({t2});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
......@@ -297,7 +301,10 @@ TEST_CASE(nop_convert)
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = mm->add_parameter("x", s);
auto t = mm->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, x);
auto t = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
x);
mm->add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
......@@ -314,10 +321,10 @@ TEST_CASE(concat_transpose1)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto xt = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto yt = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
auto concat = mm->add_instruction(migraphx::op::concat{2}, xt, yt);
auto t = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat);
auto xt = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), x);
auto yt = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), y);
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), xt, yt);
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), concat);
mm->add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
......@@ -338,10 +345,10 @@ TEST_CASE(concat_transpose2)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto xt = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
auto concat = mm->add_instruction(migraphx::op::concat{-1}, xt, yt);
auto t = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
auto xt = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), x);
auto yt = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), y);
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", -1}}), xt, yt);
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat);
mm->add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
......@@ -362,10 +369,10 @@ TEST_CASE(concat_transpose3)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}});
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 4}});
auto xt = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
auto concat = mm->add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
auto xt = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), x);
auto yt = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), y);
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), xt, yt);
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat);
mm->add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
......@@ -386,10 +393,10 @@ TEST_CASE(concat_transpose4)
auto sy = migraphx::shape{migraphx::shape::float_type, {1, 12, 1, 64}};
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sy);
auto xt = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
auto concat = mm->add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
auto xt = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), x);
auto yt = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), y);
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), xt, yt);
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat);
mm->add_return({t});
migraphx::program p1 = p;
......@@ -406,9 +413,10 @@ TEST_CASE(nested_concat)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto concat1 = mm->add_instruction(migraphx::op::concat{1}, x, y);
auto concat2 = mm->add_instruction(migraphx::op::concat{1}, y, x);
auto concat3 = mm->add_instruction(migraphx::op::concat{1}, concat1, concat2);
auto concat1 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y);
auto concat2 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), y, x);
auto concat3 =
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), concat1, concat2);
mm->add_return({concat3});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
......@@ -428,9 +436,10 @@ TEST_CASE(nested_concat_partial)
auto y = mm->add_parameter("y", s);
auto l = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 4, 3, 4}}));
auto concat1 = mm->add_instruction(migraphx::op::concat{1}, x, y);
auto concat2 = mm->add_instruction(migraphx::op::concat{1}, y, x);
auto concat3 = mm->add_instruction(migraphx::op::concat{1}, concat1, concat2, l);
auto concat1 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y);
auto concat2 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), y, x);
auto concat3 =
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), concat1, concat2, l);
mm->add_return({concat3});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
......@@ -448,8 +457,8 @@ TEST_CASE(multibroadcast_simplify)
std::vector<size_t> s_lens{1, 2, 3, 4};
auto s = migraphx::shape{migraphx::shape::float_type, s_lens};
auto x = mm->add_parameter("x", s);
auto y = mm->add_instruction(migraphx::op::multibroadcast{s_lens}, x);
mm->add_instruction(migraphx::op::mul{}, y, y);
auto y = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", s_lens}}), x);
mm->add_instruction(migraphx::make_op("mul"), y, y);
auto n = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
......@@ -461,8 +470,10 @@ TEST_CASE(double_slice1)
auto* mm1 = p1.get_main_module();
{
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice1 = mm1->add_instruction(migraphx::op::slice{{0}, {32}, {256}}, x);
auto slice2 = mm1->add_instruction(migraphx::op::slice{{0}, {32}, {64}}, slice1);
auto slice1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {256}}}), x);
auto slice2 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {64}}}), slice1);
mm1->add_return({slice2});
}
run_pass(p1);
......@@ -471,7 +482,8 @@ TEST_CASE(double_slice1)
auto* mm2 = p2.get_main_module();
{
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice = mm2->add_instruction(migraphx::op::slice{{0}, {64}, {96}}, x);
auto slice = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {96}}}), x);
mm2->add_return({slice});
}
EXPECT(p1 == p2);
......@@ -483,8 +495,10 @@ TEST_CASE(double_slice2)
auto* mm1 = p1.get_main_module();
{
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice1 = mm1->add_instruction(migraphx::op::slice{{0}, {32}, {128}}, x);
auto slice2 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {32}}, slice1);
auto slice1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {128}}}), x);
auto slice2 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {32}}}), slice1);
mm1->add_return({slice2});
}
run_pass(p1);
......@@ -493,7 +507,8 @@ TEST_CASE(double_slice2)
auto* mm2 = p2.get_main_module();
{
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice = mm2->add_instruction(migraphx::op::slice{{0}, {32}, {64}}, x);
auto slice = mm2->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {64}}}), x);
mm2->add_return({slice});
}
EXPECT(p1 == p2);
......@@ -505,8 +520,10 @@ TEST_CASE(double_slice_multi_axes)
auto* mm1 = p1.get_main_module();
{
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
auto slice1 = mm1->add_instruction(migraphx::op::slice{{0}, {32}, {128}}, x);
auto slice2 = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {32}}, slice1);
auto slice1 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {128}}}), x);
auto slice2 = mm1->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {32}}}), slice1);
mm1->add_return({slice2});
}
run_pass(p1);
......@@ -516,7 +533,10 @@ TEST_CASE(double_slice_multi_axes)
auto* mm2 = p2.get_main_module();
{
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
auto slice = mm2->add_instruction(migraphx::op::slice{{0, 1}, {32, 0}, {128, 32}}, x);
auto slice = mm2->add_instruction(
migraphx::make_op("slice",
{{"axes", {0, 1}}, {"starts", {32, 0}}, {"ends", {128, 32}}}),
x);
mm2->add_return({slice});
}
EXPECT(p1 == p2);
......
......@@ -10,6 +10,10 @@
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include "test.hpp"
migraphx::program
......@@ -39,7 +43,7 @@ TEST_CASE(add_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
mm->add_instruction(migraphx::op::add{}, l0, l1);
mm->add_instruction(migraphx::make_op("add"), l0, l1);
auto prog = optimize_tf("add_test.pb", false);
EXPECT(p == prog);
......@@ -50,7 +54,7 @@ TEST_CASE(addv2_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
p.add_instruction(migraphx::op::add{}, l0, l1);
p.add_instruction(migraphx::make_op("add"), l0, l1);
auto prog = optimize_tf("addv2_test.pb", false);
EXPECT(p == prog);
......@@ -65,9 +69,11 @@ TEST_CASE(add_bcast_test)
migraphx::shape s0{migraphx::shape::float_type, {2, 3}};
auto l0 = mm->add_parameter("0", s0);
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}});
auto l2 = mm->add_instruction(migraphx::op::multibroadcast{s0.lens()}, l0);
auto l3 = mm->add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1);
mm->add_instruction(migraphx::op::add{}, l2, l3);
auto l2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", s0.lens()}}), l0);
auto l3 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", s0.lens()}}), l1);
mm->add_instruction(migraphx::make_op("add"), l2, l3);
auto prog = optimize_tf("add_bcast_test.pb", false);
EXPECT(p == prog);
......@@ -80,8 +86,8 @@ TEST_CASE(argmax_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 5, 6, 7}});
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
auto ins = mm->add_instruction(migraphx::op::argmax{2}, l0);
mm->add_instruction(migraphx::op::squeeze{{2}}, ins);
auto ins = mm->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), l0);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
auto prog = parse_tf("argmax_test.pb", false, {{"0", {4, 5, 6, 7}}});
EXPECT(p == prog);
......@@ -94,8 +100,8 @@ TEST_CASE(argmin_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
auto ins = mm->add_instruction(migraphx::op::argmin{2}, l0);
mm->add_instruction(migraphx::op::squeeze{{2}}, ins);
auto ins = mm->add_instruction(migraphx::make_op("argmin", {{"axis", 2}}), l0);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
auto prog = parse_tf("argmin_test.pb", false);
EXPECT(p == prog);
......@@ -111,9 +117,9 @@ TEST_CASE(assert_less_equal_test)
auto l1 = mm->add_parameter("1", s0);
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {0, 1}};
auto l2 = mm->add_literal(l);
mm->add_instruction(migraphx::op::add{}, l0, l1);
auto l3 = mm->add_instruction(migraphx::op::identity{}, l0, l1);
mm->add_instruction(migraphx::op::identity{}, l3, l2);
mm->add_instruction(migraphx::make_op("add"), l0, l1);
auto l3 = mm->add_instruction(migraphx::make_op("identity"), l0, l1);
mm->add_instruction(migraphx::make_op("identity"), l3, l2);
auto prog = optimize_tf("assert_less_equal_test.pb", false);
EXPECT(p == prog);
......@@ -127,10 +133,12 @@ TEST_CASE(batchmatmul_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 4}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}});
auto trans_l0 = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
auto trans_l1 = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1);
auto trans_l0 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l0);
auto trans_l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
mm->add_instruction(migraphx::op::dot{}, trans_l0, trans_l1);
mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1);
auto prog = optimize_tf("batchmatmul_test.pb", false);
EXPECT(p == prog);
......@@ -193,8 +201,9 @@ TEST_CASE(biasadd_test)
uint64_t axis = 1;
auto l0 = mm->add_parameter("0", s0);
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
auto l2 = mm->add_instruction(migraphx::op::broadcast{axis, l0->get_shape().lens()}, l1);
mm->add_instruction(migraphx::op::add{}, l0, l2);
auto l2 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l0->get_shape().lens()}}), l1);
mm->add_instruction(migraphx::make_op("add"), l0, l2);
auto prog = optimize_tf("biasadd_test.pb", true);
EXPECT(p == prog);
......@@ -210,8 +219,9 @@ TEST_CASE(biasadd_scalar_test)
auto l0 = mm->add_parameter("0", s0);
auto l1 = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1.0}});
auto l2 = mm->add_instruction(migraphx::op::broadcast{axis, l0->get_shape().lens()}, l1);
mm->add_instruction(migraphx::op::add{}, l0, l2);
auto l2 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l0->get_shape().lens()}}), l1);
mm->add_instruction(migraphx::make_op("add"), l0, l2);
auto prog = optimize_tf("biasadd_scalar_test.pb", true);
EXPECT(p == prog);
......@@ -223,7 +233,10 @@ TEST_CASE(cast_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
mm->add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, l0);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
l0);
auto prog = optimize_tf("cast_test.pb", false);
EXPECT(p == prog);
......@@ -243,7 +256,7 @@ TEST_CASE(concat_test)
// add the literal using a vector in order to set stride to 1 (like in tf parser)
mm->add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis});
mm->add_instruction(migraphx::op::concat{axis}, l0, l1);
mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1);
auto prog = optimize_tf("concat_test.pb", false);
EXPECT(p == prog);
......@@ -277,7 +290,7 @@ migraphx::program create_conv()
op.padding = {1, 1};
op.stride = {1, 1};
op.dilation = {1, 1};
auto l2 = mm->add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
auto l2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 2, 0, 1}}}), l1);
mm->add_instruction(op, l0, l2);
return p;
}
......@@ -316,9 +329,9 @@ TEST_CASE(depthwiseconv_test)
op.stride = {1, 1};
op.dilation = {1, 1};
op.group = 3;
auto l3 = mm->add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
auto l4 = mm->add_instruction(migraphx::op::contiguous{}, l3);
auto l5 = mm->add_instruction(migraphx::op::reshape{{3, 1, 3, 3}}, l4);
auto l3 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 2, 0, 1}}}), l1);
auto l4 = mm->add_instruction(migraphx::make_op("contiguous"), l3);
auto l5 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1, 3, 3}}}), l4);
mm->add_instruction(op, l0, l5);
auto prog = optimize_tf("depthwise_conv_test.pb", true);
......@@ -333,7 +346,7 @@ TEST_CASE(expanddims_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
mm->add_literal(0);
mm->add_instruction(migraphx::op::reshape{{1, 2, 3, 4}}, l0);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 3, 4}}}), l0);
auto prog = optimize_tf("expanddims_test.pb", false);
EXPECT(p == prog);
......@@ -348,7 +361,7 @@ TEST_CASE(expanddims_test_neg_dims)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
mm->add_literal(-1);
mm->add_instruction(migraphx::op::reshape{{2, 3, 4, 1}}, l0);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 4, 1}}}), l0);
auto prog = optimize_tf("expanddims_neg_test.pb", false);
EXPECT(p == prog);
......@@ -366,7 +379,7 @@ TEST_CASE(gather_test)
mm->add_literal(1);
int axis = 1;
mm->add_instruction(migraphx::op::gather{axis}, l0, l1);
mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l0, l1);
auto prog = optimize_tf("gather_test.pb", false);
EXPECT(p == prog);
......@@ -378,7 +391,7 @@ TEST_CASE(identity_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
mm->add_instruction(migraphx::op::identity{}, l0);
mm->add_instruction(migraphx::make_op("identity"), l0);
auto prog = optimize_tf("identity_test.pb", false);
EXPECT(p == prog);
......@@ -392,10 +405,10 @@ TEST_CASE(matmul_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}});
auto trans_l0 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto trans_l1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto trans_l0 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l0);
auto trans_l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
mm->add_instruction(migraphx::op::dot{}, trans_l0, trans_l1);
mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1);
auto prog = optimize_tf("matmul_test.pb", false);
EXPECT(p == prog);
......@@ -413,7 +426,7 @@ TEST_CASE(mean_test)
migraphx::op::reduce_mean op{{2, 3}};
mm->add_instruction(op, l0);
auto l3 = mm->add_instruction(op, l0);
mm->add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), l3);
auto prog = optimize_tf("mean_test.pb", false);
EXPECT(p == prog);
......@@ -426,10 +439,10 @@ TEST_CASE(mean_test_nhwc)
auto* mm = p.get_main_module();
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto l1 = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);
migraphx::op::reduce_mean op{{1, 2}};
auto l2 = mm->add_instruction(op, l1);
mm->add_instruction(migraphx::op::squeeze{{1, 2}}, l2);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2}}}), l2);
auto prog = optimize_tf("mean_test_nhwc.pb", true);
EXPECT(p == prog);
......@@ -443,7 +456,7 @@ TEST_CASE(mul_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
mm->add_instruction(migraphx::op::mul{}, l0, l1);
mm->add_instruction(migraphx::make_op("mul"), l0, l1);
auto prog = optimize_tf("mul_test.pb", false);
EXPECT(p == prog);
......@@ -462,7 +475,7 @@ TEST_CASE(onehot_test)
auto l1 = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 2}}, {1, 0, 0, 1}});
int axis = 0;
mm->add_instruction(migraphx::op::gather{axis}, l1, l0);
mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l1, l0);
auto prog = optimize_tf("onehot_test.pb", false);
EXPECT(p == prog);
......@@ -488,13 +501,15 @@ TEST_CASE(pack_test)
std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t axis = 1;
std::transform(args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](migraphx::instruction_ref arg) {
return mm->add_instruction(migraphx::op::unsqueeze{{axis}}, arg);
});
mm->add_instruction(migraphx::op::concat{static_cast<int>(axis)}, unsqueezed_args);
std::transform(
args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](migraphx::instruction_ref arg) {
return mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {axis}}}), arg);
});
mm->add_instruction(migraphx::make_op("concat", {{"axis", static_cast<int>(axis)}}),
unsqueezed_args);
auto prog = optimize_tf("pack_test.pb", false);
EXPECT(p == prog);
......@@ -506,11 +521,11 @@ TEST_CASE(pack_test_nhwc)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt0 = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto lt0 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt1 = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l1);
auto lt1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l1);
auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt2 = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l2);
auto lt2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l2);
std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t nchw_axis = 3;
......@@ -519,9 +534,11 @@ TEST_CASE(pack_test_nhwc)
args.end(),
std::back_inserter(unsqueezed_args),
[&](migraphx::instruction_ref arg) {
return mm->add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg);
return mm->add_instruction(
migraphx::make_op("unsqueeze", {{"axes", {nchw_axis}}}), arg);
});
mm->add_instruction(migraphx::op::concat{static_cast<int>(nchw_axis)}, unsqueezed_args);
mm->add_instruction(migraphx::make_op("concat", {{"axis", static_cast<int>(nchw_axis)}}),
unsqueezed_args);
auto prog = optimize_tf("pack_test_nhwc.pb", true);
EXPECT(p == prog);
......@@ -552,7 +569,7 @@ TEST_CASE(pow_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
mm->add_instruction(migraphx::op::pow{}, l0, l1);
mm->add_instruction(migraphx::make_op("pow"), l0, l1);
auto prog = optimize_tf("pow_test.pb", false);
EXPECT(p == prog);
......@@ -564,7 +581,7 @@ TEST_CASE(relu_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
mm->add_instruction(migraphx::op::relu{}, l0);
mm->add_instruction(migraphx::make_op("relu"), l0);
auto prog = optimize_tf("relu_test.pb", false);
EXPECT(p == prog);
......@@ -579,9 +596,11 @@ TEST_CASE(relu6_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_lens});
auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction(migraphx::op::multibroadcast{input_lens}, min_val);
max_val = mm->add_instruction(migraphx::op::multibroadcast{input_lens}, max_val);
mm->add_instruction(migraphx::op::clip{}, l0, min_val, max_val);
min_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val);
max_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val);
mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
auto prog = optimize_tf("relu6_test.pb", false);
EXPECT(p == prog);
......@@ -596,7 +615,7 @@ TEST_CASE(reshape_test)
migraphx::shape s0{migraphx::shape::int32_type, {4}};
// in tf, the second arg is a literal that contains new dimensions
mm->add_literal(migraphx::literal{s0, {1, 1, 1, 16}});
mm->add_instruction(migraphx::op::reshape{{1, 1, 1, 16}}, l0);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 1, 16}}}), l0);
auto prog = optimize_tf("reshape_test.pb", false);
EXPECT(p == prog);
......@@ -608,7 +627,7 @@ TEST_CASE(rsqrt_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
mm->add_instruction(migraphx::op::rsqrt{}, l0);
mm->add_instruction(migraphx::make_op("rsqrt"), l0);
auto prog = optimize_tf("rsqrt_test.pb", false);
EXPECT(p == prog);
......@@ -655,7 +674,7 @@ TEST_CASE(softmax_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
mm->add_instruction(migraphx::op::softmax{1}, l0);
mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), l0);
auto prog = optimize_tf("softmax_test.pb", false);
EXPECT(p == prog);
......@@ -672,11 +691,14 @@ TEST_CASE(split_test)
mm->add_literal(1); // split axis
mm->add_literal(1); // concat axis
mm->add_literal(1); // concat axis
auto l1 = mm->add_instruction(migraphx::op::slice{axes, {0, 0}, {5, 10}}, l0);
auto l2 = mm->add_instruction(migraphx::op::slice{axes, {0, 10}, {5, 20}}, l0);
auto l3 = mm->add_instruction(migraphx::op::slice{axes, {0, 20}, {5, 30}}, l0);
mm->add_instruction(migraphx::op::concat{1}, l1, l2);
mm->add_instruction(migraphx::op::concat{1}, l2, l3);
auto l1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 0}}, {"ends", {5, 10}}}), l0);
auto l2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 10}}, {"ends", {5, 20}}}), l0);
auto l3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 20}}, {"ends", {5, 30}}}), l0);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3);
auto prog = parse_tf("split_test.pb", false);
......@@ -691,7 +713,7 @@ TEST_CASE(split_test_one_output)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
mm->add_literal(1); // num_splits
mm->add_literal(1); // split axis
mm->add_instruction(migraphx::op::identity{}, l0);
mm->add_instruction(migraphx::make_op("identity"), l0);
auto prog = parse_tf("split_test_one_output.pb", false);
......@@ -711,11 +733,14 @@ TEST_CASE(split_test_vector_as_input)
mm->add_literal(1); // split axis
mm->add_literal(1); // concat axis
mm->add_literal(1); // concat axis
auto l1 = mm->add_instruction(migraphx::op::slice{axes, {0, 0}, {5, 4}}, l0);
auto l2 = mm->add_instruction(migraphx::op::slice{axes, {0, 4}, {5, 19}}, l0);
auto l3 = mm->add_instruction(migraphx::op::slice{axes, {0, 19}, {5, 30}}, l0);
mm->add_instruction(migraphx::op::concat{1}, l1, l2);
mm->add_instruction(migraphx::op::concat{1}, l2, l3);
auto l1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 0}}, {"ends", {5, 4}}}), l0);
auto l2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 4}}, {"ends", {5, 19}}}), l0);
auto l3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 19}}, {"ends", {5, 30}}}), l0);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3);
auto prog = parse_tf("split_test_vector_as_input.pb", false);
......@@ -729,7 +754,7 @@ TEST_CASE(sqdiff_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
mm->add_instruction(migraphx::op::sqdiff{}, l0, l1);
mm->add_instruction(migraphx::make_op("sqdiff"), l0, l1);
auto prog = optimize_tf("sqdiff_test.pb", false);
EXPECT(p == prog);
......@@ -741,7 +766,7 @@ TEST_CASE(squeeze_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}});
mm->add_instruction(migraphx::op::squeeze{{0, 3}}, l0);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 3}}}), l0);
auto prog = optimize_tf("squeeze_test.pb", false);
EXPECT(p == prog);
......@@ -753,7 +778,7 @@ TEST_CASE(stopgradient_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
mm->add_instruction(migraphx::op::identity{}, l0);
mm->add_instruction(migraphx::make_op("identity"), l0);
auto prog = optimize_tf("stopgradient_test.pb", false);
EXPECT(p == prog);
......@@ -765,7 +790,7 @@ TEST_CASE(stridedslice_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
auto l1 = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 0, 0, 0};
......@@ -774,7 +799,7 @@ TEST_CASE(stridedslice_test)
std::iota(op.axes.begin(), op.axes.end(), 0);
auto l2 = mm->add_instruction(op, l1);
auto shrink_axis = 1;
mm->add_instruction(migraphx::op::squeeze{{shrink_axis}}, l2);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {shrink_axis}}}), l2);
auto prog = optimize_tf("stridedslice_test.pb", true);
EXPECT(p == prog);
......@@ -800,9 +825,9 @@ TEST_CASE(stridedslice_masks_test)
mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
std::vector<int>{1, 1, 1, 1});
auto l1 = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);
auto l2 = mm->add_instruction(op, l1);
mm->add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l2);
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 3, 1, 2}}}), l2);
auto prog = parse_tf("stridedslice_masks_test.pb", true);
EXPECT(p == prog);
......@@ -815,7 +840,7 @@ TEST_CASE(sub_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
mm->add_instruction(migraphx::op::sub{}, l0, l1);
mm->add_instruction(migraphx::make_op("sub"), l0, l1);
auto prog = parse_tf("sub_test.pb", false);
EXPECT(p == prog);
......@@ -828,7 +853,7 @@ TEST_CASE(tanh_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
mm->add_instruction(migraphx::op::sub{}, l0, l1);
mm->add_instruction(migraphx::make_op("sub"), l0, l1);
auto prog = parse_tf("sub_test.pb", false);
EXPECT(p == prog);
......@@ -842,7 +867,7 @@ TEST_CASE(transpose_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
migraphx::shape s0{migraphx::shape::int32_type, {4}};
mm->add_literal(migraphx::literal{s0, {0, 2, 3, 1}});
mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);
auto prog = optimize_tf("transpose_test.pb", false);
EXPECT(p == prog);
......@@ -854,7 +879,7 @@ TEST_CASE(variable_batch_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
mm->add_instruction(migraphx::op::identity{}, l0);
mm->add_instruction(migraphx::make_op("identity"), l0);
auto prog = optimize_tf("variable_batch_test.pb", false);
EXPECT(p == prog);
......
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
{
......@@ -14,12 +14,15 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1);
auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l2);
auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3);
auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3);
return p;
}
};
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
{
......@@ -17,7 +17,8 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction(migraphx::op::quant_dot{1, 3}, l1, l2, l3);
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), l1, l2, l3);
return p;
}
};
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
{
......@@ -14,10 +14,11 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
migraphx::shape m2_shape{migraphx::shape::float_type, {5}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
auto ul2 = mm->add_instruction(migraphx::op::unsqueeze{{1}}, l2);
auto bul2 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 5, 1}}, ul2);
auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2);
auto bul2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 5, 1}}}), ul2);
mm->add_instruction(migraphx::op::dot{}, l1, bul2);
mm->add_instruction(migraphx::make_op("dot"), l1, bul2);
return p;
}
......
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
{
......@@ -14,9 +14,10 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
auto bl2 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2);
auto bl2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4}}}), l2);
mm->add_instruction(migraphx::op::dot{}, l1, bl2);
mm->add_instruction(migraphx::make_op("dot"), l1, bl2);
return p;
}
......
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2>
{
......@@ -14,9 +14,10 @@ struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2>
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 4}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
auto bl2 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2);
auto bl2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4}}}), l2);
mm->add_instruction(migraphx::op::dot{}, l1, bl2);
mm->add_instruction(migraphx::make_op("dot"), l1, bl2);
return p;
}
......
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3>
{
......@@ -13,10 +13,11 @@ struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3>
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, l1);
auto l2 = mm->add_parameter("2", m2_shape);
auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 2, 3}}}), l1);
auto l2 = mm->add_parameter("2", m2_shape);
mm->add_instruction(migraphx::op::dot{}, bl1, l2);
mm->add_instruction(migraphx::make_op("dot"), bl1, l2);
return p;
}
......
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4>
{
......@@ -13,10 +13,11 @@ struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4>
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, l1);
auto l2 = mm->add_parameter("2", m2_shape);
auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 2, 3}}}), l1);
auto l2 = mm->add_parameter("2", m2_shape);
mm->add_instruction(migraphx::op::dot{}, bl1, l2);
mm->add_instruction(migraphx::make_op("dot"), bl1, l2);
return p;
}
......
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5>
{
......@@ -13,10 +13,11 @@ struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5>
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = mm->add_parameter("2", m2_shape);
auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 2, 3}}}), l1);
auto l2 = mm->add_parameter("2", m2_shape);
mm->add_instruction(migraphx::op::dot{}, bl1, l2);
mm->add_instruction(migraphx::make_op("dot"), bl1, l2);
return p;
}
......
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6>
{
......@@ -13,11 +13,13 @@ struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6>
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 2, 3}}}), l1);
auto l2 = mm->add_parameter("2", m2_shape);
auto bl2 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 3, 4}}, l2);
auto bl2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 3, 4}}}), l2);
mm->add_instruction(migraphx::op::dot{}, bl1, bl2);
mm->add_instruction(migraphx::make_op("dot"), bl1, bl2);
return p;
}
......
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7>
{
......@@ -13,10 +13,11 @@ struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7>
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = mm->add_parameter("2", m2_shape);
auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 2, 3}}}), l1);
auto l2 = mm->add_parameter("2", m2_shape);
mm->add_instruction(migraphx::op::dot{}, bl1, l2);
mm->add_instruction(migraphx::make_op("dot"), bl1, l2);
return p;
}
......
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct gemm_2args_mv : verify_program<gemm_2args_mv>
{
......@@ -14,9 +14,9 @@ struct gemm_2args_mv : verify_program<gemm_2args_mv>
migraphx::shape m2_shape{migraphx::shape::float_type, {5}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
auto ul2 = mm->add_instruction(migraphx::op::unsqueeze{{1}}, l2);
auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2);
mm->add_instruction(migraphx::op::dot{}, l1, ul2);
mm->add_instruction(migraphx::make_op("dot"), l1, ul2);
return p;
}
......
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct gemm_2args_vbm : verify_program<gemm_2args_vbm>
{
......@@ -13,13 +13,14 @@ struct gemm_2args_vbm : verify_program<gemm_2args_vbm>
migraphx::shape m1_shape{migraphx::shape::float_type, {5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 2, 5, 4}};
auto l1 = mm->add_parameter("1", m1_shape);
auto ul1 = mm->add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto bul1 = mm->add_instruction(migraphx::op::multibroadcast{{2, 2, 1, 5}}, ul1);
auto ul1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l1);
auto bul1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 1, 5}}}), ul1);
auto l2 = mm->add_parameter("2", m2_shape);
auto res = mm->add_instruction(migraphx::op::dot{}, bul1, l2);
mm->add_instruction(migraphx::op::squeeze{{2}}, res);
auto res = mm->add_instruction(migraphx::make_op("dot"), bul1, l2);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), res);
return p;
}
......
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct gemm_2args_vm : verify_program<gemm_2args_vm>
{
......@@ -13,11 +13,11 @@ struct gemm_2args_vm : verify_program<gemm_2args_vm>
migraphx::shape m1_shape{migraphx::shape::float_type, {5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5, 4}};
auto l1 = mm->add_parameter("1", m1_shape);
auto ul1 = mm->add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto ul1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l1);
auto l2 = mm->add_parameter("2", m2_shape);
auto res = mm->add_instruction(migraphx::op::dot{}, ul1, l2);
mm->add_instruction(migraphx::op::squeeze{{0}}, res);
auto res = mm->add_instruction(migraphx::make_op("dot"), ul1, l2);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
return p;
}
......
......@@ -2,7 +2,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct gemm_2args_vv : verify_program<gemm_2args_vv>
{
......@@ -13,14 +13,14 @@ struct gemm_2args_vv : verify_program<gemm_2args_vv>
migraphx::shape m1_shape{migraphx::shape::float_type, {8}};
migraphx::shape m2_shape{migraphx::shape::float_type, {8}};
auto l1 = mm->add_parameter("1", m1_shape);
auto ul1 = mm->add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto ul1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l1);
auto l2 = mm->add_parameter("2", m2_shape);
auto ul2 = mm->add_instruction(migraphx::op::unsqueeze{{1}}, l2);
auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2);
float alpha = 0.23f;
auto res = mm->add_instruction(migraphx::op::dot{alpha}, ul1, ul2);
auto sres = mm->add_instruction(migraphx::op::squeeze{{0}}, res);
mm->add_instruction(migraphx::op::squeeze{{0}}, sres);
auto res = mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ul1, ul2);
auto sres = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sres);
return p;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment