Unverified Commit 0d2606bb authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Change attributes names to be more consistent and reflect better meaning (#916)

* rename broadcast and multibroadcast output_lens attribute to out_lens attribute, and change tests and source code to reflect the same

* change the reshape attribute from dims to out_lens

* change transpose attribute's name from dims to perm to reflect better meaning

* use permutation instead of perm for transpose

clang formaating

* use dims instead of out_lens for reshape

clang formatting
parent d8a2a933
...@@ -86,7 +86,8 @@ TEST_CASE(add_broadcast_test) ...@@ -86,7 +86,8 @@ TEST_CASE(add_broadcast_test)
auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data}); auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data});
auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data}); auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data});
auto l3 = mm->add_instruction( auto l3 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l1->get_shape().lens()}}), l2); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l1->get_shape().lens()}}),
l2);
mm->add_instruction(migraphx::make_op("add"), l1, l3); mm->add_instruction(migraphx::make_op("add"), l1, l3);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -105,10 +106,10 @@ TEST_CASE(add_broadcast_test) ...@@ -105,10 +106,10 @@ TEST_CASE(add_broadcast_test)
std::vector<float> b_data{0, -1, -2, -3}; std::vector<float> b_data{0, -1, -2, -3};
auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data}); auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data});
auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data}); auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data});
auto l3 = mm->add_instruction( auto l3 =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 3}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3}}}), l1);
auto l4 = mm->add_instruction( auto l4 =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 3}}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3}}}), l2);
mm->add_instruction(migraphx::make_op("add"), l3, l4); mm->add_instruction(migraphx::make_op("add"), l3, l4);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -675,7 +676,7 @@ TEST_CASE(broadcast_test) ...@@ -675,7 +676,7 @@ TEST_CASE(broadcast_test)
auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data}); auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data});
auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data}); auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data});
mm->add_instruction( mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l1->get_shape().lens()}}), l2); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l1->get_shape().lens()}}), l2);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
auto output = result.get<int32_t>(); auto output = result.get<int32_t>();
...@@ -712,9 +713,9 @@ TEST_CASE(clip_test) ...@@ -712,9 +713,9 @@ TEST_CASE(clip_test)
auto min_val = mm->add_literal(0.0f); auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f); auto max_val = mm->add_literal(6.0f);
min_val = min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}), min_val); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), min_val);
max_val = max_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}), max_val); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), max_val);
mm->add_instruction(migraphx::make_op("clip"), l, min_val, max_val); mm->add_instruction(migraphx::make_op("clip"), l, min_val, max_val);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -1329,8 +1330,7 @@ TEST_CASE(equal_brcst_test) ...@@ -1329,8 +1330,7 @@ TEST_CASE(equal_brcst_test)
mm->add_literal(migraphx::literal{s0, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}}); mm->add_literal(migraphx::literal{s0, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}});
migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; migraphx::shape s1{migraphx::shape::float_type, {3, 1}};
auto l1 = mm->add_literal(migraphx::literal{s1, {1.1, -1.5, 0.0}}); auto l1 = mm->add_literal(migraphx::literal{s1, {1.1, -1.5, 0.0}});
auto bl1 = auto bl1 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), l1);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3, 3}}}), l1);
auto eq = mm->add_instruction(migraphx::make_op("equal"), l0, bl1); auto eq = mm->add_instruction(migraphx::make_op("equal"), l0, bl1);
auto r = mm->add_instruction( auto r = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
...@@ -1676,8 +1676,7 @@ TEST_CASE(greater_brcst_test) ...@@ -1676,8 +1676,7 @@ TEST_CASE(greater_brcst_test)
mm->add_literal(migraphx::literal{s0, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}}); mm->add_literal(migraphx::literal{s0, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}});
migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; migraphx::shape s1{migraphx::shape::float_type, {3, 1}};
auto l1 = mm->add_literal(migraphx::literal{s1, {1.1, -1.5, 0.0}}); auto l1 = mm->add_literal(migraphx::literal{s1, {1.1, -1.5, 0.0}});
auto bl1 = auto bl1 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), l1);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3, 3}}}), l1);
auto gr = mm->add_instruction(migraphx::make_op("greater"), l0, bl1); auto gr = mm->add_instruction(migraphx::make_op("greater"), l0, bl1);
auto r = mm->add_instruction( auto r = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
...@@ -2129,7 +2128,7 @@ TEST_CASE(imagescaler_test) ...@@ -2129,7 +2128,7 @@ TEST_CASE(imagescaler_test)
auto bias_vals = mm->add_literal( auto bias_vals = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto bias_bcast = mm->add_instruction( auto bias_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", s.lens()}}), bias_vals); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), bias_vals);
mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast); mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -2176,8 +2175,7 @@ TEST_CASE(less_brcst_test) ...@@ -2176,8 +2175,7 @@ TEST_CASE(less_brcst_test)
mm->add_literal(migraphx::literal{s0, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}}); mm->add_literal(migraphx::literal{s0, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}});
migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; migraphx::shape s1{migraphx::shape::float_type, {3, 1}};
auto l1 = mm->add_literal(migraphx::literal{s1, {1.1, -1.5, 0.0}}); auto l1 = mm->add_literal(migraphx::literal{s1, {1.1, -1.5, 0.0}});
auto bl1 = auto bl1 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), l1);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3, 3}}}), l1);
auto le = mm->add_instruction(migraphx::make_op("less"), l0, bl1); auto le = mm->add_instruction(migraphx::make_op("less"), l0, bl1);
auto r = mm->add_instruction( auto r = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
...@@ -4198,7 +4196,8 @@ TEST_CASE(step_test) ...@@ -4198,7 +4196,8 @@ TEST_CASE(step_test)
std::iota(data.begin(), data.end(), 2); std::iota(data.begin(), data.end(), 2);
migraphx::shape s1{migraphx::shape::float_type, {2, 1, 4, 6}}; migraphx::shape s1{migraphx::shape::float_type, {2, 1, 4, 6}};
auto l0 = mm->add_literal(migraphx::literal{s1, data}); auto l0 = mm->add_literal(migraphx::literal{s1, data});
auto tl = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0); auto tl = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto r = mm->add_instruction( auto r = mm->add_instruction(
migraphx::make_op("step", {{"axes", {0, 1, 2}}, {"steps", {2, 2, 3}}}), tl); migraphx::make_op("step", {{"axes", {0, 1, 2}}, {"steps", {2, 2, 3}}}), tl);
mm->add_return({r}); mm->add_return({r});
...@@ -4272,7 +4271,7 @@ TEST_CASE(transpose_test) ...@@ -4272,7 +4271,7 @@ TEST_CASE(transpose_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, data}); auto l = mm->add_literal(migraphx::literal{a_shape, data});
std::vector<int64_t> perm = {0, 3, 1, 2}; std::vector<int64_t> perm = {0, 3, 1, 2};
mm->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), l); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), l);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -4286,7 +4285,8 @@ TEST_CASE(transpose_test) ...@@ -4286,7 +4285,8 @@ TEST_CASE(transpose_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, data}); auto l = mm->add_literal(migraphx::literal{a_shape, data});
std::vector<int64_t> perm = {0, 3, 1, 2}; std::vector<int64_t> perm = {0, 3, 1, 2};
auto result = mm->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), l); auto result =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), l);
mm->add_instruction(migraphx::make_op("contiguous"), result); mm->add_instruction(migraphx::make_op("contiguous"), result);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result2 = p.eval({}).back(); auto result2 = p.eval({}).back();
......
...@@ -204,7 +204,7 @@ TEST_CASE(simplify_mul_conv1) ...@@ -204,7 +204,7 @@ TEST_CASE(simplify_mul_conv1)
w); w);
auto a = m.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}})); auto a = m.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}}));
auto b = m.add_instruction( auto b = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 256, 14, 14}}}), a); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 256, 14, 14}}}), a);
auto mul = m.add_instruction(migraphx::make_op("mul"), conv, b); auto mul = m.add_instruction(migraphx::make_op("mul"), conv, b);
m.add_instruction(pass_op{}, mul); m.add_instruction(pass_op{}, mul);
EXPECT(conv->outputs().front()->name() == "mul"); EXPECT(conv->outputs().front()->name() == "mul");
...@@ -226,7 +226,7 @@ TEST_CASE(simplify_mul_slice_conv1) ...@@ -226,7 +226,7 @@ TEST_CASE(simplify_mul_slice_conv1)
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}})); auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = m1.add_instruction( auto b = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a);
auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b); auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b);
auto slice2 = m1.add_instruction( auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv);
...@@ -244,7 +244,7 @@ TEST_CASE(simplify_mul_slice_conv1) ...@@ -244,7 +244,7 @@ TEST_CASE(simplify_mul_slice_conv1)
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {384}}}), w); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {384}}}), w);
auto a = m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}})); auto a = m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = m2.add_instruction( auto b = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", {384, 1024, 1, 1}}}), a); migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {384, 1024, 1, 1}}}), a);
auto mul = m2.add_instruction(migraphx::make_op("mul"), b, wslice1); auto mul = m2.add_instruction(migraphx::make_op("mul"), b, wslice1);
auto wslice2 = m2.add_instruction( auto wslice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w);
...@@ -272,7 +272,7 @@ TEST_CASE(simplify_mul_slice_conv_overlapping_slice) ...@@ -272,7 +272,7 @@ TEST_CASE(simplify_mul_slice_conv_overlapping_slice)
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}})); auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = m1.add_instruction( auto b = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a);
auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b); auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b);
auto slice2 = m1.add_instruction( auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {383}}, {"ends", {767}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {383}}, {"ends", {767}}}), conv);
...@@ -296,7 +296,7 @@ TEST_CASE(simplify_mul_slice_conv_not_all_slice) ...@@ -296,7 +296,7 @@ TEST_CASE(simplify_mul_slice_conv_not_all_slice)
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}})); auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = m1.add_instruction( auto b = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a);
auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b); auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b);
auto c = m1.add_literal( auto c = m1.add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {1, 768, 17, 17}})); migraphx::generate_literal({migraphx::shape::int32_type, {1, 768, 17, 17}}));
...@@ -1086,10 +1086,10 @@ TEST_CASE(simplify_slice_different_axis) ...@@ -1086,10 +1086,10 @@ TEST_CASE(simplify_slice_different_axis)
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {1}}}), input);
auto one = m1.add_literal(1); auto one = m1.add_literal(1);
auto oneb = m1.add_instruction( auto oneb = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {3, 1, 4, 2}}}), one); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4, 2}}}), one);
auto two = m1.add_literal(2); auto two = m1.add_literal(2);
auto twob = m1.add_instruction( auto twob = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 3}, {"dims", {3, 2, 4, 1}}}), two); migraphx::make_op("broadcast", {{"axis", 3}, {"out_lens", {3, 2, 4, 1}}}), two);
auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb); auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1); auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1);
auto reshape1 = m1.add_instruction(r, relu1); auto reshape1 = m1.add_instruction(r, relu1);
...@@ -1709,17 +1709,17 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion) ...@@ -1709,17 +1709,17 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
auto a1 = auto a1 =
m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1)); m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
auto b1 = m1.add_instruction( auto b1 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a1); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a1);
auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b1); auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b1);
auto a2 = auto a2 =
m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2)); m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2));
auto b2 = m1.add_instruction( auto b2 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a2);
auto add1 = m1.add_instruction(migraphx::make_op("add"), mul, b2); auto add1 = m1.add_instruction(migraphx::make_op("add"), mul, b2);
auto a3 = auto a3 =
m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3)); m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
auto b3 = m1.add_instruction( auto b3 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a3); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a3);
auto slice2 = m1.add_instruction( auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv);
auto add2 = m1.add_instruction(migraphx::make_op("add"), slice2, b3); auto add2 = m1.add_instruction(migraphx::make_op("add"), slice2, b3);
...@@ -1737,7 +1737,7 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion) ...@@ -1737,7 +1737,7 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
auto a1 = auto a1 =
m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1)); m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
auto b1 = m2.add_instruction( auto b1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", {384, 1024, 1, 1}}}), a1); migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {384, 1024, 1, 1}}}), a1);
auto mul = m2.add_instruction(migraphx::make_op("mul"), b1, wslice1); auto mul = m2.add_instruction(migraphx::make_op("mul"), b1, wslice1);
auto wslice2 = m2.add_instruction( auto wslice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w);
...@@ -1749,7 +1749,7 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion) ...@@ -1749,7 +1749,7 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3)); m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
auto concat2 = m2.add_instruction(migraphx::make_op("concat"), a2, a3); auto concat2 = m2.add_instruction(migraphx::make_op("concat"), a2, a3);
auto b4 = m2.add_instruction( auto b4 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 768, 17, 17}}}), concat2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 768, 17, 17}}}), concat2);
auto add = m2.add_instruction(migraphx::make_op("add"), conv, b4); auto add = m2.add_instruction(migraphx::make_op("add"), conv, b4);
auto slice1 = m2.add_instruction( auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), add); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), add);
...@@ -1785,9 +1785,9 @@ TEST_CASE(reorder_reshape_slice) ...@@ -1785,9 +1785,9 @@ TEST_CASE(reorder_reshape_slice)
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1); auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2); auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r0); auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), r0);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r1); auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), r1);
auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), r2); auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), r2);
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2);
...@@ -1810,9 +1810,12 @@ TEST_CASE(reorder_reshape_slice) ...@@ -1810,9 +1810,12 @@ TEST_CASE(reorder_reshape_slice)
auto slc2 = m2.add_instruction( auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {20}}, {"ends", {30}}}), r); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {20}}, {"ends", {30}}}), r);
auto t0 = m2.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc0); auto t0 =
auto t1 = m2.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc1); m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc0);
auto t2 = m2.add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), slc2); auto t1 =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc1);
auto t2 =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), slc2);
auto sum = m2.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m2.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m2.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m2.add_instruction(migraphx::make_op("dot"), sum, t2);
...@@ -1857,9 +1860,9 @@ TEST_CASE(reorder_reshape_slice_move_axis1) ...@@ -1857,9 +1860,9 @@ TEST_CASE(reorder_reshape_slice_move_axis1)
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1); auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2); auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r0); auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), r0);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r1); auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), r1);
auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), r2); auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), r2);
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2);
...@@ -1878,13 +1881,13 @@ TEST_CASE(reorder_reshape_slice_move_axis1) ...@@ -1878,13 +1881,13 @@ TEST_CASE(reorder_reshape_slice_move_axis1)
auto rsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input); auto rsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = m.add_instruction( auto slc0 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp);
auto t0 = m.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc0); auto t0 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc0);
auto slc1 = m.add_instruction( auto slc1 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp);
auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc1); auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc1);
auto slc2 = m.add_instruction( auto slc2 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp);
auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), slc2); auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), slc2);
auto sum = m.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m.add_instruction(migraphx::make_op("dot"), sum, t2);
...@@ -2051,9 +2054,9 @@ TEST_CASE(reorder_slice_trans) ...@@ -2051,9 +2054,9 @@ TEST_CASE(reorder_slice_trans)
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}), migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
input); input);
auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), slc0); auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc0);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), slc1); auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc1);
auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), slc2); auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc2);
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, t2); auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, t2);
...@@ -2066,7 +2069,7 @@ TEST_CASE(reorder_slice_trans) ...@@ -2066,7 +2069,7 @@ TEST_CASE(reorder_slice_trans)
migraphx::module m2; migraphx::module m2;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}}; auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto r = m2.add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), input); auto r = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input);
auto slc0 = m2.add_instruction( auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {640}}}), r); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {640}}}), r);
...@@ -2110,9 +2113,12 @@ TEST_CASE(reorder_slice_trans_diff_perm) ...@@ -2110,9 +2113,12 @@ TEST_CASE(reorder_slice_trans_diff_perm)
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}), migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
input); input);
auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc0); auto t0 =
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc1); m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc0);
auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), slc2); auto t1 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc1);
auto t2 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), slc2);
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2);
......
...@@ -30,12 +30,12 @@ migraphx::instruction_ref add_quantize_op(migraphx::module& m, ...@@ -30,12 +30,12 @@ migraphx::instruction_ref add_quantize_op(migraphx::module& m,
migraphx::instruction_ref scale_mb; migraphx::instruction_ref scale_mb;
if(scale->get_shape().lens().front() == 1) if(scale->get_shape().lens().front() == 1)
scale_mb = scale_mb =
m.add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), scale); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), scale);
else else
scale_mb = scale_mb = m.add_instruction(
m.add_instruction(migraphx::make_op("broadcast", {{"axis", 1}, {"dims", lens}}), scale); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), scale);
auto shift_mb = auto shift_mb =
m.add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), shift); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), shift);
return m.add_instruction(migraphx::make_op(name), x, scale_mb, shift_mb); return m.add_instruction(migraphx::make_op(name), x, scale_mb, shift_mb);
} }
...@@ -48,10 +48,10 @@ migraphx::instruction_ref add_quantize_op(migraphx::module& m, ...@@ -48,10 +48,10 @@ migraphx::instruction_ref add_quantize_op(migraphx::module& m,
migraphx::instruction_ref scale_mb; migraphx::instruction_ref scale_mb;
if(scale->get_shape().lens().front() == 1) if(scale->get_shape().lens().front() == 1)
scale_mb = scale_mb =
m.add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), scale); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), scale);
else else
scale_mb = scale_mb = m.add_instruction(
m.add_instruction(migraphx::make_op("broadcast", {{"axis", 1}, {"dims", lens}}), scale); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), scale);
return m.add_instruction(migraphx::make_op(name), x, scale_mb); return m.add_instruction(migraphx::make_op(name), x, scale_mb);
} }
...@@ -402,7 +402,7 @@ TEST_CASE(conv_bias_add) ...@@ -402,7 +402,7 @@ TEST_CASE(conv_bias_add)
d5, d5,
d1); d1);
auto b1 = m1.add_instruction( auto b1 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, b1); auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, b1);
m1.add_return({a1}); m1.add_return({a1});
} }
...@@ -428,7 +428,7 @@ TEST_CASE(conv_bias_add) ...@@ -428,7 +428,7 @@ TEST_CASE(conv_bias_add)
weights); weights);
auto d6 = add_quantize_op(m2, "dequantizelinear", c1, scale1); auto d6 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
auto b1 = m2.add_instruction( auto b1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d6, b1); auto a1 = m2.add_instruction(migraphx::make_op("add"), d6, b1);
m2.add_return({a1}); m2.add_return({a1});
} }
...@@ -470,7 +470,7 @@ TEST_CASE(conv_pooling_dot) ...@@ -470,7 +470,7 @@ TEST_CASE(conv_pooling_dot)
d5, d5,
d1); d1);
auto bc1 = m1.add_instruction( auto bc1 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1); auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1);
auto ap = m1.add_instruction(migraphx::make_op("pooling", auto ap = m1.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", "average"},
...@@ -486,8 +486,8 @@ TEST_CASE(conv_pooling_dot) ...@@ -486,8 +486,8 @@ TEST_CASE(conv_pooling_dot)
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4); m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4);
auto q5 = add_quantize_op(m1, "quantizelinear", dot, scale, zero); auto q5 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(m1, "dequantizelinear", q5, scale, zero); auto d9 = add_quantize_op(m1, "dequantizelinear", q5, scale, zero);
auto mb1 = m1.add_instruction( auto mb1 =
migraphx::make_op("multibroadcast", {{"output_lens", {1, 1000}}}), d3); m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = m1.add_instruction(migraphx::make_op("add"), d9, mb1); auto a2 = m1.add_instruction(migraphx::make_op("add"), d9, mb1);
m1.add_return({a2}); m1.add_return({a2});
} }
...@@ -517,7 +517,7 @@ TEST_CASE(conv_pooling_dot) ...@@ -517,7 +517,7 @@ TEST_CASE(conv_pooling_dot)
weights); weights);
auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1); auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
auto bc1 = m2.add_instruction( auto bc1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1); auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1);
auto ap = m2.add_instruction(migraphx::make_op("pooling", auto ap = m2.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", "average"},
...@@ -531,8 +531,8 @@ TEST_CASE(conv_pooling_dot) ...@@ -531,8 +531,8 @@ TEST_CASE(conv_pooling_dot)
auto dot = auto dot =
m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), q4, db); m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), q4, db);
auto d9 = add_quantize_op(m2, "dequantizelinear", dot, scale2); auto d9 = add_quantize_op(m2, "dequantizelinear", dot, scale2);
auto mb1 = m2.add_instruction( auto mb1 =
migraphx::make_op("multibroadcast", {{"output_lens", {1, 1000}}}), d3); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1); auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1);
m2.add_return({a2}); m2.add_return({a2});
} }
...@@ -574,7 +574,7 @@ TEST_CASE(mobilenet_snippet) ...@@ -574,7 +574,7 @@ TEST_CASE(mobilenet_snippet)
d5, d5,
d1); d1);
auto bc1 = mm.add_instruction( auto bc1 = mm.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1); auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1);
auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero); auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero);
auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero); auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero);
...@@ -594,8 +594,8 @@ TEST_CASE(mobilenet_snippet) ...@@ -594,8 +594,8 @@ TEST_CASE(mobilenet_snippet)
mm.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4); mm.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4);
auto q5 = add_quantize_op(mm, "quantizelinear", dot, scale, zero); auto q5 = add_quantize_op(mm, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(mm, "dequantizelinear", q5, scale, zero); auto d9 = add_quantize_op(mm, "dequantizelinear", q5, scale, zero);
auto mb1 = mm.add_instruction( auto mb1 =
migraphx::make_op("multibroadcast", {{"output_lens", {1, 1000}}}), d3); mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = mm.add_instruction(migraphx::make_op("add"), d9, mb1); auto a2 = mm.add_instruction(migraphx::make_op("add"), d9, mb1);
mm.add_return({a2}); mm.add_return({a2});
......
...@@ -22,7 +22,7 @@ TEST_CASE(double_contig) ...@@ -22,7 +22,7 @@ TEST_CASE(double_contig)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), t1); auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), t1);
auto c2 = mm->add_instruction(migraphx::make_op("contiguous"), c1); auto c2 = mm->add_instruction(migraphx::make_op("contiguous"), c1);
mm->add_return({c2}); mm->add_return({c2});
...@@ -42,8 +42,8 @@ TEST_CASE(double_transpose) ...@@ -42,8 +42,8 @@ TEST_CASE(double_transpose)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t1); auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t1);
mm->add_return({t2}); mm->add_return({t2});
EXPECT(mm->get_output_shapes().back().standard()); EXPECT(mm->get_output_shapes().back().standard());
EXPECT(not mm->get_output_shapes().back().transposed()); EXPECT(not mm->get_output_shapes().back().transposed());
...@@ -61,9 +61,9 @@ TEST_CASE(double_transpose_contig) ...@@ -61,9 +61,9 @@ TEST_CASE(double_transpose_contig)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), t1); auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), t1);
auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), c1); auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), c1);
auto c2 = mm->add_instruction(migraphx::make_op("contiguous"), t2); auto c2 = mm->add_instruction(migraphx::make_op("contiguous"), t2);
mm->add_return({c2}); mm->add_return({c2});
EXPECT(mm->get_output_shapes().back().standard()); EXPECT(mm->get_output_shapes().back().standard());
...@@ -82,7 +82,7 @@ TEST_CASE(single_transpose) ...@@ -82,7 +82,7 @@ TEST_CASE(single_transpose)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
mm->add_return({t1}); mm->add_return({t1});
EXPECT(not mm->get_output_shapes().back().standard()); EXPECT(not mm->get_output_shapes().back().standard());
EXPECT(mm->get_output_shapes().back().transposed()); EXPECT(mm->get_output_shapes().back().transposed());
...@@ -100,8 +100,8 @@ TEST_CASE(double_transpose_sin_pass) ...@@ -100,8 +100,8 @@ TEST_CASE(double_transpose_sin_pass)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t1); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t1);
EXPECT(mm->get_output_shapes().back().standard()); EXPECT(mm->get_output_shapes().back().standard());
EXPECT(not mm->get_output_shapes().back().transposed()); EXPECT(not mm->get_output_shapes().back().transposed());
run_pass(*mm); run_pass(*mm);
...@@ -119,7 +119,7 @@ TEST_CASE(single_transpose_sin_pass) ...@@ -119,7 +119,7 @@ TEST_CASE(single_transpose_sin_pass)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
EXPECT(not mm->get_output_shapes().back().standard()); EXPECT(not mm->get_output_shapes().back().standard());
EXPECT(mm->get_output_shapes().back().transposed()); EXPECT(mm->get_output_shapes().back().transposed());
run_pass(*mm); run_pass(*mm);
...@@ -137,7 +137,8 @@ TEST_CASE(reshape_transpose) ...@@ -137,7 +137,8 @@ TEST_CASE(reshape_transpose)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto r1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 4, 28, 56, 56}}}), x); auto r1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 4, 28, 56, 56}}}), x);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 1, 3, 4}}}), r1); auto t =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3, 4}}}), r1);
auto ct = m.add_instruction(migraphx::make_op("contiguous"), t); auto ct = m.add_instruction(migraphx::make_op("contiguous"), t);
auto r2 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 112, 56, 56}}}), ct); auto r2 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 112, 56, 56}}}), ct);
m.add_return({r2}); m.add_return({r2});
...@@ -154,7 +155,7 @@ TEST_CASE(transpose_contiguous) ...@@ -154,7 +155,7 @@ TEST_CASE(transpose_contiguous)
auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}}; auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), x); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x);
auto c1 = m.add_instruction(migraphx::make_op("contiguous"), t); auto c1 = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_return({c1}); m.add_return({c1});
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
...@@ -170,7 +171,7 @@ TEST_CASE(transpose_double_contiguous) ...@@ -170,7 +171,7 @@ TEST_CASE(transpose_double_contiguous)
auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}}; auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), x); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x);
auto c1 = m.add_instruction(migraphx::make_op("contiguous"), t); auto c1 = m.add_instruction(migraphx::make_op("contiguous"), t);
auto c2 = m.add_instruction(migraphx::make_op("contiguous"), c1); auto c2 = m.add_instruction(migraphx::make_op("contiguous"), c1);
m.add_return({c2}); m.add_return({c2});
...@@ -188,8 +189,8 @@ TEST_CASE(transpose_partial1) ...@@ -188,8 +189,8 @@ TEST_CASE(transpose_partial1)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), x); auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), x);
auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), t1); auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), t1);
m.add_return({t2}); m.add_return({t2});
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
...@@ -204,9 +205,9 @@ TEST_CASE(transpose_partial2) ...@@ -204,9 +205,9 @@ TEST_CASE(transpose_partial2)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), x); auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), x);
auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), t1); auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), t1);
auto t3 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), t2); auto t3 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), t2);
m.add_return({t3}); m.add_return({t3});
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
...@@ -221,10 +222,10 @@ TEST_CASE(transpose_partial3) ...@@ -221,10 +222,10 @@ TEST_CASE(transpose_partial3)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), x); auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), x);
auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), t1); auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), t1);
auto t3 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), t2); auto t3 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), t2);
auto t4 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), t3); auto t4 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), t3);
m.add_return({t4}); m.add_return({t4});
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
...@@ -239,7 +240,7 @@ TEST_CASE(nop_transpose1) ...@@ -239,7 +240,7 @@ TEST_CASE(nop_transpose1)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), x); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), x);
m.add_return({t}); m.add_return({t});
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
...@@ -254,10 +255,10 @@ TEST_CASE(nop_transpose2) ...@@ -254,10 +255,10 @@ TEST_CASE(nop_transpose2)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), x); auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), x);
auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), t1); auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), t1);
auto t3 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), t2); auto t3 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), t2);
auto t4 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), t3); auto t4 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), t3);
m.add_instruction(pass_op{}, t4); m.add_instruction(pass_op{}, t4);
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
...@@ -274,8 +275,10 @@ TEST_CASE(nop_transpose3) ...@@ -274,8 +275,10 @@ TEST_CASE(nop_transpose3)
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto y = m.add_parameter("y", s); auto y = m.add_parameter("y", s);
auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), x, y); auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), x, y);
auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2, 3}}}), concat); auto t1 =
auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t1); m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3}}}), concat);
auto t2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t1);
m.add_return({t2}); m.add_return({t2});
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
...@@ -309,10 +312,11 @@ TEST_CASE(concat_transpose1) ...@@ -309,10 +312,11 @@ TEST_CASE(concat_transpose1)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto y = m.add_parameter("y", s); auto y = m.add_parameter("y", s);
auto xt = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), x); auto xt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), x);
auto yt = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), y); auto yt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), y);
auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), xt, yt); auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), xt, yt);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), concat); auto t =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), concat);
m.add_return({t}); m.add_return({t});
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
...@@ -332,10 +336,11 @@ TEST_CASE(concat_transpose2) ...@@ -332,10 +336,11 @@ TEST_CASE(concat_transpose2)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto y = m.add_parameter("y", s); auto y = m.add_parameter("y", s);
auto xt = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), x); auto xt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto yt = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), y); auto yt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), y);
auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", -1}}), xt, yt); auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", -1}}), xt, yt);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat); auto t =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), concat);
m.add_return({t}); m.add_return({t});
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
...@@ -355,10 +360,11 @@ TEST_CASE(concat_transpose3) ...@@ -355,10 +360,11 @@ TEST_CASE(concat_transpose3)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}); auto x = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}});
auto y = m.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 4}}); auto y = m.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 4}});
auto xt = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), x); auto xt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto yt = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), y); auto yt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), y);
auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), xt, yt); auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), xt, yt);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat); auto t =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), concat);
m.add_return({t}); m.add_return({t});
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
...@@ -378,10 +384,11 @@ TEST_CASE(concat_transpose4) ...@@ -378,10 +384,11 @@ TEST_CASE(concat_transpose4)
auto sy = migraphx::shape{migraphx::shape::float_type, {1, 12, 1, 64}}; auto sy = migraphx::shape{migraphx::shape::float_type, {1, 12, 1, 64}};
auto x = m.add_parameter("x", sx); auto x = m.add_parameter("x", sx);
auto y = m.add_parameter("y", sy); auto y = m.add_parameter("y", sy);
auto xt = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), x); auto xt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto yt = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), y); auto yt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), y);
auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), xt, yt); auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), xt, yt);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat); auto t =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), concat);
m.add_return({t}); m.add_return({t});
migraphx::module m1 = m; migraphx::module m1 = m;
...@@ -438,7 +445,7 @@ TEST_CASE(multibroadcast_simplify) ...@@ -438,7 +445,7 @@ TEST_CASE(multibroadcast_simplify)
std::vector<size_t> s_lens{1, 2, 3, 4}; std::vector<size_t> s_lens{1, 2, 3, 4};
auto s = migraphx::shape{migraphx::shape::float_type, s_lens}; auto s = migraphx::shape{migraphx::shape::float_type, s_lens};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto y = m.add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", s_lens}}), x); auto y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s_lens}}), x);
m.add_instruction(migraphx::make_op("mul"), y, y); m.add_instruction(migraphx::make_op("mul"), y, y);
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
run_pass(m); run_pass(m);
...@@ -547,8 +554,8 @@ TEST_CASE(optimize_resize) ...@@ -547,8 +554,8 @@ TEST_CASE(optimize_resize)
std::vector<int64_t> dims = {1, 1, 2, 1, 2, 1}; std::vector<int64_t> dims = {1, 1, 2, 1, 2, 1};
auto rspx = m.add_instruction(migraphx::make_op("reshape", {{"dims", dims}}), inx); auto rspx = m.add_instruction(migraphx::make_op("reshape", {{"dims", dims}}), inx);
std::vector<int64_t> mb_dims = {1, 2, 2, 2, 2, 3}; std::vector<int64_t> mb_dims = {1, 2, 2, 2, 2, 3};
auto mbx = m.add_instruction( auto mbx =
migraphx::make_op("multibroadcast", {{"output_lens", mb_dims}}), rspx); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_dims}}), rspx);
auto std_mb = m.add_instruction(migraphx::make_op("contiguous"), mbx); auto std_mb = m.add_instruction(migraphx::make_op("contiguous"), mbx);
std::vector<int64_t> orig_dims = {1, 2, 4, 6}; std::vector<int64_t> orig_dims = {1, 2, 4, 6};
auto rmb = m.add_instruction(migraphx::make_op("reshape", {{"dims", orig_dims}}), std_mb); auto rmb = m.add_instruction(migraphx::make_op("reshape", {{"dims", orig_dims}}), std_mb);
...@@ -849,8 +856,8 @@ TEST_CASE(reshape_cont) ...@@ -849,8 +856,8 @@ TEST_CASE(reshape_cont)
auto inx = m.add_parameter("x", sx); auto inx = m.add_parameter("x", sx);
auto iny = m.add_parameter("y", sy); auto iny = m.add_parameter("y", sy);
auto mb_inx = m.add_instruction( auto mb_inx =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 4, 6}}}), inx); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 6}}}), inx);
auto std_inx = m.add_instruction(migraphx::make_op("contiguous"), mb_inx); auto std_inx = m.add_instruction(migraphx::make_op("contiguous"), mb_inx);
auto rsp = auto rsp =
m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), std_inx); m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), std_inx);
...@@ -870,8 +877,8 @@ TEST_CASE(reshape_cont) ...@@ -870,8 +877,8 @@ TEST_CASE(reshape_cont)
auto inx = m.add_parameter("x", sx); auto inx = m.add_parameter("x", sx);
auto iny = m.add_parameter("y", sy); auto iny = m.add_parameter("y", sy);
auto mb_inx = m.add_instruction( auto mb_inx =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 4, 6}}}), inx); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 6}}}), inx);
auto rsp_iny = m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4, 6}}}), iny); auto rsp_iny = m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4, 6}}}), iny);
auto sum = m.add_instruction(migraphx::make_op("add"), mb_inx, rsp_iny); auto sum = m.add_instruction(migraphx::make_op("add"), mb_inx, rsp_iny);
auto r = m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), sum); auto r = m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), sum);
...@@ -892,12 +899,13 @@ TEST_CASE(reshape_input_non_std) ...@@ -892,12 +899,13 @@ TEST_CASE(reshape_input_non_std)
auto inx = m.add_parameter("x", sx); auto inx = m.add_parameter("x", sx);
auto iny = m.add_parameter("y", sy); auto iny = m.add_parameter("y", sy);
auto mb_inx = m.add_instruction( auto mb_inx =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 4, 6}}}), inx); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 6}}}), inx);
auto std_inx = m.add_instruction(migraphx::make_op("contiguous"), mb_inx); auto std_inx = m.add_instruction(migraphx::make_op("contiguous"), mb_inx);
auto rsp = auto rsp =
m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), std_inx); m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), std_inx);
auto ty = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), iny); auto ty =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), iny);
auto r = m.add_instruction(migraphx::make_op("add"), rsp, ty); auto r = m.add_instruction(migraphx::make_op("add"), rsp, ty);
m.add_return({r}); m.add_return({r});
...@@ -919,8 +927,8 @@ TEST_CASE(reshape_cont_nonpw) ...@@ -919,8 +927,8 @@ TEST_CASE(reshape_cont_nonpw)
auto inx = m.add_parameter("x", sx); auto inx = m.add_parameter("x", sx);
auto iny = m.add_parameter("y", sy); auto iny = m.add_parameter("y", sy);
auto mb_inx = m.add_instruction( auto mb_inx =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 4, 6}}}), inx); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 6}}}), inx);
auto std_inx = m.add_instruction(migraphx::make_op("contiguous"), mb_inx); auto std_inx = m.add_instruction(migraphx::make_op("contiguous"), mb_inx);
auto rsp = auto rsp =
m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), std_inx); m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), std_inx);
......
...@@ -87,7 +87,7 @@ TEST_CASE(add_bcast_test) ...@@ -87,7 +87,7 @@ TEST_CASE(add_bcast_test)
auto l0 = mm->add_parameter("0", s0); auto l0 = mm->add_parameter("0", s0);
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}});
auto l2 = auto l2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", s0.lens()}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s0.lens()}}), l1);
mm->add_instruction(migraphx::make_op("add"), l0, l2); mm->add_instruction(migraphx::make_op("add"), l0, l2);
auto prog = optimize_tf("add_bcast_test.pb", false); auto prog = optimize_tf("add_bcast_test.pb", false);
...@@ -151,9 +151,9 @@ TEST_CASE(batchmatmul_test) ...@@ -151,9 +151,9 @@ TEST_CASE(batchmatmul_test)
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}});
auto trans_l0 = auto trans_l0 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l0); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l0);
auto trans_l1 = auto trans_l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1); mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1);
auto prog = optimize_tf("batchmatmul_test.pb", false); auto prog = optimize_tf("batchmatmul_test.pb", false);
...@@ -220,7 +220,7 @@ TEST_CASE(biasadd_test) ...@@ -220,7 +220,7 @@ TEST_CASE(biasadd_test)
auto l0 = mm->add_parameter("0", s0); auto l0 = mm->add_parameter("0", s0);
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
auto l2 = mm->add_instruction( auto l2 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l0->get_shape().lens()}}), l1); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l0->get_shape().lens()}}), l1);
mm->add_instruction(migraphx::make_op("add"), l0, l2); mm->add_instruction(migraphx::make_op("add"), l0, l2);
auto prog = optimize_tf("biasadd_test.pb", true); auto prog = optimize_tf("biasadd_test.pb", true);
...@@ -238,7 +238,7 @@ TEST_CASE(biasadd_scalar_test) ...@@ -238,7 +238,7 @@ TEST_CASE(biasadd_scalar_test)
auto l1 = mm->add_literal( auto l1 = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1.0}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1.0}});
auto l2 = mm->add_instruction( auto l2 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l0->get_shape().lens()}}), l1); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l0->get_shape().lens()}}), l1);
mm->add_instruction(migraphx::make_op("add"), l0, l2); mm->add_instruction(migraphx::make_op("add"), l0, l2);
auto prog = optimize_tf("biasadd_scalar_test.pb", true); auto prog = optimize_tf("biasadd_scalar_test.pb", true);
...@@ -308,7 +308,8 @@ migraphx::program create_conv() ...@@ -308,7 +308,8 @@ migraphx::program create_conv()
op.padding = {1, 1, 1, 1}; op.padding = {1, 1, 1, 1};
op.stride = {1, 1}; op.stride = {1, 1};
op.dilation = {1, 1}; op.dilation = {1, 1};
auto l2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 2, 0, 1}}}), l1); auto l2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), l1);
mm->add_instruction(op, l0, l2); mm->add_instruction(op, l0, l2);
return p; return p;
} }
...@@ -359,10 +360,10 @@ TEST_CASE(conv_relu6_test) ...@@ -359,10 +360,10 @@ TEST_CASE(conv_relu6_test)
auto l0 = std::prev(mm->end()); auto l0 = std::prev(mm->end());
auto min_val = mm->add_literal(0.0f); auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f); auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction( min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val); min_val);
max_val = mm->add_instruction( max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val); max_val);
mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
auto prog = optimize_tf("conv_relu6_test.pb", true); auto prog = optimize_tf("conv_relu6_test.pb", true);
...@@ -387,7 +388,8 @@ TEST_CASE(depthwiseconv_test) ...@@ -387,7 +388,8 @@ TEST_CASE(depthwiseconv_test)
op.stride = {1, 1}; op.stride = {1, 1};
op.dilation = {1, 1}; op.dilation = {1, 1};
op.group = 3; op.group = 3;
auto l3 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 2, 0, 1}}}), l1); auto l3 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), l1);
auto l4 = mm->add_instruction(migraphx::make_op("contiguous"), l3); auto l4 = mm->add_instruction(migraphx::make_op("contiguous"), l3);
auto l5 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1, 3, 3}}}), l4); auto l5 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1, 3, 3}}}), l4);
mm->add_instruction(op, l0, l5); mm->add_instruction(op, l0, l5);
...@@ -463,8 +465,10 @@ TEST_CASE(matmul_test) ...@@ -463,8 +465,10 @@ TEST_CASE(matmul_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}});
auto trans_l0 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l0); auto trans_l0 =
auto trans_l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l0);
auto trans_l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1); mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1);
auto prog = optimize_tf("matmul_test.pb", false); auto prog = optimize_tf("matmul_test.pb", false);
...@@ -497,7 +501,8 @@ TEST_CASE(mean_test_nhwc) ...@@ -497,7 +501,8 @@ TEST_CASE(mean_test_nhwc)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}}; migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0); auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
migraphx::op::reduce_mean op{{1, 2}}; migraphx::op::reduce_mean op{{1, 2}};
auto l2 = mm->add_instruction(op, l1); auto l2 = mm->add_instruction(op, l1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2}}}), l2); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2}}}), l2);
...@@ -595,11 +600,14 @@ TEST_CASE(pack_test_nhwc) ...@@ -595,11 +600,14 @@ TEST_CASE(pack_test_nhwc)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt0 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0); auto lt0 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l1); auto lt1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l1);
auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l2); auto lt2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l2);
std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2}; std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
std::vector<migraphx::instruction_ref> unsqueezed_args; std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t nchw_axis = 3; int64_t nchw_axis = 3;
...@@ -688,10 +696,10 @@ TEST_CASE(relu6_test) ...@@ -688,10 +696,10 @@ TEST_CASE(relu6_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_lens}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_lens});
auto min_val = mm->add_literal(0.0f); auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f); auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction( min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val); min_val);
max_val = mm->add_instruction( max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val); max_val);
mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
auto prog = optimize_tf("relu6_test.pb", false); auto prog = optimize_tf("relu6_test.pb", false);
...@@ -882,7 +890,8 @@ TEST_CASE(stridedslice_test) ...@@ -882,7 +890,8 @@ TEST_CASE(stridedslice_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0); auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
std::size_t num_axes = 4; std::size_t num_axes = 4;
migraphx::op::slice op; migraphx::op::slice op;
op.starts = {0, 0, 0, 0}; op.starts = {0, 0, 0, 0};
...@@ -917,9 +926,11 @@ TEST_CASE(stridedslice_masks_test) ...@@ -917,9 +926,11 @@ TEST_CASE(stridedslice_masks_test)
mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
std::vector<int>{1, 1, 1, 1}); std::vector<int>{1, 1, 1, 1});
auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0); auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto l2 = mm->add_instruction(op, l1); auto l2 = mm->add_instruction(op, l1);
auto l3 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 3, 1, 2}}}), l2); auto l3 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), l2);
mm->add_return({l3}); mm->add_return({l3});
auto prog = parse_tf("stridedslice_masks_test.pb", true); auto prog = parse_tf("stridedslice_masks_test.pb", true);
...@@ -961,7 +972,7 @@ TEST_CASE(transpose_test) ...@@ -961,7 +972,7 @@ TEST_CASE(transpose_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
migraphx::shape s0{migraphx::shape::int32_type, {4}}; migraphx::shape s0{migraphx::shape::int32_type, {4}};
mm->add_literal(migraphx::literal{s0, {0, 2, 3, 1}}); mm->add_literal(migraphx::literal{s0, {0, 2, 3, 1}});
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto prog = optimize_tf("transpose_test.pb", false); auto prog = optimize_tf("transpose_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
......
...@@ -15,11 +15,11 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1> ...@@ -15,11 +15,11 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = auto tl1 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = auto tl2 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3); migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3);
......
...@@ -15,10 +15,10 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4> ...@@ -15,10 +15,10 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl1 = auto tl1 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 0, 1, 2}}}), l1); migraphx::make_op("transpose", {{"permutation", {3, 0, 1, 2}}}), l1);
auto tl2 = auto tl2 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 1, 2, 0}}}), l2); migraphx::make_op("transpose", {{"permutation", {3, 1, 2, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, tl2); mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, tl2);
return p; return p;
} }
......
...@@ -15,11 +15,11 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5> ...@@ -15,11 +15,11 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl1 = auto tl1 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
auto sl1 = mm->add_instruction(migraphx::make_op("add"), tl1, tl1); auto sl1 = mm->add_instruction(migraphx::make_op("add"), tl1, tl1);
auto tl2 = auto tl2 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto sl2 = mm->add_instruction(migraphx::make_op("add"), tl2, tl2); auto sl2 = mm->add_instruction(migraphx::make_op("add"), tl2, tl2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}}), sl1, sl2); mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}}), sl1, sl2);
return p; return p;
......
...@@ -16,7 +16,7 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv> ...@@ -16,7 +16,7 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2); auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2);
auto bul2 = mm->add_instruction( auto bul2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 5, 1}}}), ul2); migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 5, 1}}}), ul2);
mm->add_instruction(migraphx::make_op("dot"), l1, bul2); mm->add_instruction(migraphx::make_op("dot"), l1, bul2);
......
...@@ -14,8 +14,8 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1> ...@@ -14,8 +14,8 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto bl2 = mm->add_instruction( auto bl2 =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4}}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4}}}), l2);
mm->add_instruction(migraphx::make_op("dot"), l1, bl2); mm->add_instruction(migraphx::make_op("dot"), l1, bl2);
......
...@@ -14,8 +14,8 @@ struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2> ...@@ -14,8 +14,8 @@ struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2>
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto bl2 = mm->add_instruction( auto bl2 =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4}}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4}}}), l2);
mm->add_instruction(migraphx::make_op("dot"), l1, bl2); mm->add_instruction(migraphx::make_op("dot"), l1, bl2);
......
...@@ -13,8 +13,8 @@ struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3> ...@@ -13,8 +13,8 @@ struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3>
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction( auto bl1 =
migraphx::make_op("multibroadcast", {{"output_lens", {3, 2, 3}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 3}}}), l1);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
mm->add_instruction(migraphx::make_op("dot"), bl1, l2); mm->add_instruction(migraphx::make_op("dot"), bl1, l2);
......
...@@ -13,8 +13,8 @@ struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4> ...@@ -13,8 +13,8 @@ struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4>
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction( auto bl1 =
migraphx::make_op("multibroadcast", {{"output_lens", {3, 2, 3}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 3}}}), l1);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
mm->add_instruction(migraphx::make_op("dot"), bl1, l2); mm->add_instruction(migraphx::make_op("dot"), bl1, l2);
......
...@@ -14,7 +14,7 @@ struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5> ...@@ -14,7 +14,7 @@ struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5>
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction( auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 2, 3}}}), l1); migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 2, 3}}}), l1);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
mm->add_instruction(migraphx::make_op("dot"), bl1, l2); mm->add_instruction(migraphx::make_op("dot"), bl1, l2);
......
...@@ -14,10 +14,10 @@ struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6> ...@@ -14,10 +14,10 @@ struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6>
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction( auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 2, 3}}}), l1); migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 2, 3}}}), l1);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto bl2 = mm->add_instruction( auto bl2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 3, 4}}}), l2); migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 3, 4}}}), l2);
mm->add_instruction(migraphx::make_op("dot"), bl1, bl2); mm->add_instruction(migraphx::make_op("dot"), bl1, bl2);
......
...@@ -14,7 +14,7 @@ struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7> ...@@ -14,7 +14,7 @@ struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7>
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction( auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 2, 3}}}), l1); migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 2, 3}}}), l1);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
mm->add_instruction(migraphx::make_op("dot"), bl1, l2); mm->add_instruction(migraphx::make_op("dot"), bl1, l2);
......
...@@ -15,7 +15,7 @@ struct gemm_2args_vbm : verify_program<gemm_2args_vbm> ...@@ -15,7 +15,7 @@ struct gemm_2args_vbm : verify_program<gemm_2args_vbm>
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto ul1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l1); auto ul1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l1);
auto bul1 = mm->add_instruction( auto bul1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 1, 5}}}), ul1); migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 1, 5}}}), ul1);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
......
...@@ -14,7 +14,8 @@ struct gemm_multi_transpose : verify_program<gemm_multi_transpose> ...@@ -14,7 +14,8 @@ struct gemm_multi_transpose : verify_program<gemm_multi_transpose>
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 2, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 2, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), l2); auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), l2);
float alpha = 1.0f; float alpha = 1.0f;
float beta = 1.0f; float beta = 1.0f;
......
...@@ -15,7 +15,8 @@ struct quant_dot_3args_2 : verify_program<quant_dot_3args_2> ...@@ -15,7 +15,8 @@ struct quant_dot_3args_2 : verify_program<quant_dot_3args_2>
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction( mm->add_instruction(
......
...@@ -16,7 +16,8 @@ struct quant_dot_3args_3 : verify_program<quant_dot_3args_3> ...@@ -16,7 +16,8 @@ struct quant_dot_3args_3 : verify_program<quant_dot_3args_3>
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2); auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), l1, tl2, l3); migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), l1, tl2, l3);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment