Unverified Commit 0d2606bb authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Change attributes names to be more consistent and reflect better meaning (#916)

* rename broadcast and multibroadcast output_lens attribute to out_lens attribute, and change tests and source code to reflect the same

* change the reshape attribute from dims to out_lens

* change transpose attribute's name from dims to perm to reflect better meaning

* use permutation instead of perm for transpose

clang formaating

* use dims instead of out_lens for reshape

clang formatting
parent d8a2a933
...@@ -86,7 +86,8 @@ TEST_CASE(add_broadcast_test) ...@@ -86,7 +86,8 @@ TEST_CASE(add_broadcast_test)
auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data}); auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data});
auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data}); auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data});
auto l3 = mm->add_instruction( auto l3 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l1->get_shape().lens()}}), l2); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l1->get_shape().lens()}}),
l2);
mm->add_instruction(migraphx::make_op("add"), l1, l3); mm->add_instruction(migraphx::make_op("add"), l1, l3);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -105,10 +106,10 @@ TEST_CASE(add_broadcast_test) ...@@ -105,10 +106,10 @@ TEST_CASE(add_broadcast_test)
std::vector<float> b_data{0, -1, -2, -3}; std::vector<float> b_data{0, -1, -2, -3};
auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data}); auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data});
auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data}); auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data});
auto l3 = mm->add_instruction( auto l3 =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 3}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3}}}), l1);
auto l4 = mm->add_instruction( auto l4 =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 3}}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3}}}), l2);
mm->add_instruction(migraphx::make_op("add"), l3, l4); mm->add_instruction(migraphx::make_op("add"), l3, l4);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -675,7 +676,7 @@ TEST_CASE(broadcast_test) ...@@ -675,7 +676,7 @@ TEST_CASE(broadcast_test)
auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data}); auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data});
auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data}); auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data});
mm->add_instruction( mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l1->get_shape().lens()}}), l2); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l1->get_shape().lens()}}), l2);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
auto output = result.get<int32_t>(); auto output = result.get<int32_t>();
...@@ -712,9 +713,9 @@ TEST_CASE(clip_test) ...@@ -712,9 +713,9 @@ TEST_CASE(clip_test)
auto min_val = mm->add_literal(0.0f); auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f); auto max_val = mm->add_literal(6.0f);
min_val = min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}), min_val); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), min_val);
max_val = max_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}), max_val); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), max_val);
mm->add_instruction(migraphx::make_op("clip"), l, min_val, max_val); mm->add_instruction(migraphx::make_op("clip"), l, min_val, max_val);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -1328,11 +1329,10 @@ TEST_CASE(equal_brcst_test) ...@@ -1328,11 +1329,10 @@ TEST_CASE(equal_brcst_test)
auto l0 = auto l0 =
mm->add_literal(migraphx::literal{s0, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}}); mm->add_literal(migraphx::literal{s0, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}});
migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; migraphx::shape s1{migraphx::shape::float_type, {3, 1}};
auto l1 = mm->add_literal(migraphx::literal{s1, {1.1, -1.5, 0.0}}); auto l1 = mm->add_literal(migraphx::literal{s1, {1.1, -1.5, 0.0}});
auto bl1 = auto bl1 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), l1);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3, 3}}}), l1); auto eq = mm->add_instruction(migraphx::make_op("equal"), l0, bl1);
auto eq = mm->add_instruction(migraphx::make_op("equal"), l0, bl1); auto r = mm->add_instruction(
auto r = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
eq); eq);
...@@ -1675,11 +1675,10 @@ TEST_CASE(greater_brcst_test) ...@@ -1675,11 +1675,10 @@ TEST_CASE(greater_brcst_test)
auto l0 = auto l0 =
mm->add_literal(migraphx::literal{s0, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}}); mm->add_literal(migraphx::literal{s0, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}});
migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; migraphx::shape s1{migraphx::shape::float_type, {3, 1}};
auto l1 = mm->add_literal(migraphx::literal{s1, {1.1, -1.5, 0.0}}); auto l1 = mm->add_literal(migraphx::literal{s1, {1.1, -1.5, 0.0}});
auto bl1 = auto bl1 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), l1);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3, 3}}}), l1); auto gr = mm->add_instruction(migraphx::make_op("greater"), l0, bl1);
auto gr = mm->add_instruction(migraphx::make_op("greater"), l0, bl1); auto r = mm->add_instruction(
auto r = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
gr); gr);
...@@ -2129,7 +2128,7 @@ TEST_CASE(imagescaler_test) ...@@ -2129,7 +2128,7 @@ TEST_CASE(imagescaler_test)
auto bias_vals = mm->add_literal( auto bias_vals = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto bias_bcast = mm->add_instruction( auto bias_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", s.lens()}}), bias_vals); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), bias_vals);
mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast); mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -2175,11 +2174,10 @@ TEST_CASE(less_brcst_test) ...@@ -2175,11 +2174,10 @@ TEST_CASE(less_brcst_test)
auto l0 = auto l0 =
mm->add_literal(migraphx::literal{s0, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}}); mm->add_literal(migraphx::literal{s0, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}});
migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; migraphx::shape s1{migraphx::shape::float_type, {3, 1}};
auto l1 = mm->add_literal(migraphx::literal{s1, {1.1, -1.5, 0.0}}); auto l1 = mm->add_literal(migraphx::literal{s1, {1.1, -1.5, 0.0}});
auto bl1 = auto bl1 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), l1);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3, 3}}}), l1); auto le = mm->add_instruction(migraphx::make_op("less"), l0, bl1);
auto le = mm->add_instruction(migraphx::make_op("less"), l0, bl1); auto r = mm->add_instruction(
auto r = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
le); le);
...@@ -4198,8 +4196,9 @@ TEST_CASE(step_test) ...@@ -4198,8 +4196,9 @@ TEST_CASE(step_test)
std::iota(data.begin(), data.end(), 2); std::iota(data.begin(), data.end(), 2);
migraphx::shape s1{migraphx::shape::float_type, {2, 1, 4, 6}}; migraphx::shape s1{migraphx::shape::float_type, {2, 1, 4, 6}};
auto l0 = mm->add_literal(migraphx::literal{s1, data}); auto l0 = mm->add_literal(migraphx::literal{s1, data});
auto tl = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0); auto tl = mm->add_instruction(
auto r = mm->add_instruction( migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto r = mm->add_instruction(
migraphx::make_op("step", {{"axes", {0, 1, 2}}, {"steps", {2, 2, 3}}}), tl); migraphx::make_op("step", {{"axes", {0, 1, 2}}, {"steps", {2, 2, 3}}}), tl);
mm->add_return({r}); mm->add_return({r});
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
...@@ -4272,7 +4271,7 @@ TEST_CASE(transpose_test) ...@@ -4272,7 +4271,7 @@ TEST_CASE(transpose_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, data}); auto l = mm->add_literal(migraphx::literal{a_shape, data});
std::vector<int64_t> perm = {0, 3, 1, 2}; std::vector<int64_t> perm = {0, 3, 1, 2};
mm->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), l); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), l);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -4286,7 +4285,8 @@ TEST_CASE(transpose_test) ...@@ -4286,7 +4285,8 @@ TEST_CASE(transpose_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, data}); auto l = mm->add_literal(migraphx::literal{a_shape, data});
std::vector<int64_t> perm = {0, 3, 1, 2}; std::vector<int64_t> perm = {0, 3, 1, 2};
auto result = mm->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), l); auto result =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), l);
mm->add_instruction(migraphx::make_op("contiguous"), result); mm->add_instruction(migraphx::make_op("contiguous"), result);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result2 = p.eval({}).back(); auto result2 = p.eval({}).back();
......
...@@ -204,7 +204,7 @@ TEST_CASE(simplify_mul_conv1) ...@@ -204,7 +204,7 @@ TEST_CASE(simplify_mul_conv1)
w); w);
auto a = m.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}})); auto a = m.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}}));
auto b = m.add_instruction( auto b = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 256, 14, 14}}}), a); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 256, 14, 14}}}), a);
auto mul = m.add_instruction(migraphx::make_op("mul"), conv, b); auto mul = m.add_instruction(migraphx::make_op("mul"), conv, b);
m.add_instruction(pass_op{}, mul); m.add_instruction(pass_op{}, mul);
EXPECT(conv->outputs().front()->name() == "mul"); EXPECT(conv->outputs().front()->name() == "mul");
...@@ -226,7 +226,7 @@ TEST_CASE(simplify_mul_slice_conv1) ...@@ -226,7 +226,7 @@ TEST_CASE(simplify_mul_slice_conv1)
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}})); auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = m1.add_instruction( auto b = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a);
auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b); auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b);
auto slice2 = m1.add_instruction( auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv);
...@@ -244,7 +244,7 @@ TEST_CASE(simplify_mul_slice_conv1) ...@@ -244,7 +244,7 @@ TEST_CASE(simplify_mul_slice_conv1)
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {384}}}), w); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {384}}}), w);
auto a = m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}})); auto a = m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = m2.add_instruction( auto b = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", {384, 1024, 1, 1}}}), a); migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {384, 1024, 1, 1}}}), a);
auto mul = m2.add_instruction(migraphx::make_op("mul"), b, wslice1); auto mul = m2.add_instruction(migraphx::make_op("mul"), b, wslice1);
auto wslice2 = m2.add_instruction( auto wslice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w);
...@@ -272,7 +272,7 @@ TEST_CASE(simplify_mul_slice_conv_overlapping_slice) ...@@ -272,7 +272,7 @@ TEST_CASE(simplify_mul_slice_conv_overlapping_slice)
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}})); auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = m1.add_instruction( auto b = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a);
auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b); auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b);
auto slice2 = m1.add_instruction( auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {383}}, {"ends", {767}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {383}}, {"ends", {767}}}), conv);
...@@ -296,7 +296,7 @@ TEST_CASE(simplify_mul_slice_conv_not_all_slice) ...@@ -296,7 +296,7 @@ TEST_CASE(simplify_mul_slice_conv_not_all_slice)
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}})); auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = m1.add_instruction( auto b = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a);
auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b); auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b);
auto c = m1.add_literal( auto c = m1.add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {1, 768, 17, 17}})); migraphx::generate_literal({migraphx::shape::int32_type, {1, 768, 17, 17}}));
...@@ -1086,10 +1086,10 @@ TEST_CASE(simplify_slice_different_axis) ...@@ -1086,10 +1086,10 @@ TEST_CASE(simplify_slice_different_axis)
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {1}}}), input);
auto one = m1.add_literal(1); auto one = m1.add_literal(1);
auto oneb = m1.add_instruction( auto oneb = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {3, 1, 4, 2}}}), one); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4, 2}}}), one);
auto two = m1.add_literal(2); auto two = m1.add_literal(2);
auto twob = m1.add_instruction( auto twob = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 3}, {"dims", {3, 2, 4, 1}}}), two); migraphx::make_op("broadcast", {{"axis", 3}, {"out_lens", {3, 2, 4, 1}}}), two);
auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb); auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1); auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1);
auto reshape1 = m1.add_instruction(r, relu1); auto reshape1 = m1.add_instruction(r, relu1);
...@@ -1709,17 +1709,17 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion) ...@@ -1709,17 +1709,17 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
auto a1 = auto a1 =
m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1)); m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
auto b1 = m1.add_instruction( auto b1 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a1); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a1);
auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b1); auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b1);
auto a2 = auto a2 =
m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2)); m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2));
auto b2 = m1.add_instruction( auto b2 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a2);
auto add1 = m1.add_instruction(migraphx::make_op("add"), mul, b2); auto add1 = m1.add_instruction(migraphx::make_op("add"), mul, b2);
auto a3 = auto a3 =
m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3)); m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
auto b3 = m1.add_instruction( auto b3 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a3); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a3);
auto slice2 = m1.add_instruction( auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv);
auto add2 = m1.add_instruction(migraphx::make_op("add"), slice2, b3); auto add2 = m1.add_instruction(migraphx::make_op("add"), slice2, b3);
...@@ -1737,7 +1737,7 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion) ...@@ -1737,7 +1737,7 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
auto a1 = auto a1 =
m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1)); m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
auto b1 = m2.add_instruction( auto b1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", {384, 1024, 1, 1}}}), a1); migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {384, 1024, 1, 1}}}), a1);
auto mul = m2.add_instruction(migraphx::make_op("mul"), b1, wslice1); auto mul = m2.add_instruction(migraphx::make_op("mul"), b1, wslice1);
auto wslice2 = m2.add_instruction( auto wslice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w);
...@@ -1749,7 +1749,7 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion) ...@@ -1749,7 +1749,7 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3)); m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
auto concat2 = m2.add_instruction(migraphx::make_op("concat"), a2, a3); auto concat2 = m2.add_instruction(migraphx::make_op("concat"), a2, a3);
auto b4 = m2.add_instruction( auto b4 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 768, 17, 17}}}), concat2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 768, 17, 17}}}), concat2);
auto add = m2.add_instruction(migraphx::make_op("add"), conv, b4); auto add = m2.add_instruction(migraphx::make_op("add"), conv, b4);
auto slice1 = m2.add_instruction( auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), add); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), add);
...@@ -1785,9 +1785,9 @@ TEST_CASE(reorder_reshape_slice) ...@@ -1785,9 +1785,9 @@ TEST_CASE(reorder_reshape_slice)
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1); auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2); auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r0); auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), r0);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r1); auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), r1);
auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), r2); auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), r2);
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2);
...@@ -1810,9 +1810,12 @@ TEST_CASE(reorder_reshape_slice) ...@@ -1810,9 +1810,12 @@ TEST_CASE(reorder_reshape_slice)
auto slc2 = m2.add_instruction( auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {20}}, {"ends", {30}}}), r); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {20}}, {"ends", {30}}}), r);
auto t0 = m2.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc0); auto t0 =
auto t1 = m2.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc1); m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc0);
auto t2 = m2.add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), slc2); auto t1 =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc1);
auto t2 =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), slc2);
auto sum = m2.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m2.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m2.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m2.add_instruction(migraphx::make_op("dot"), sum, t2);
...@@ -1857,9 +1860,9 @@ TEST_CASE(reorder_reshape_slice_move_axis1) ...@@ -1857,9 +1860,9 @@ TEST_CASE(reorder_reshape_slice_move_axis1)
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1); auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2); auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r0); auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), r0);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r1); auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), r1);
auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), r2); auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), r2);
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2);
...@@ -1878,13 +1881,13 @@ TEST_CASE(reorder_reshape_slice_move_axis1) ...@@ -1878,13 +1881,13 @@ TEST_CASE(reorder_reshape_slice_move_axis1)
auto rsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input); auto rsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = m.add_instruction( auto slc0 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp);
auto t0 = m.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc0); auto t0 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc0);
auto slc1 = m.add_instruction( auto slc1 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp);
auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc1); auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc1);
auto slc2 = m.add_instruction( auto slc2 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp);
auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), slc2); auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), slc2);
auto sum = m.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m.add_instruction(migraphx::make_op("dot"), sum, t2);
...@@ -2051,9 +2054,9 @@ TEST_CASE(reorder_slice_trans) ...@@ -2051,9 +2054,9 @@ TEST_CASE(reorder_slice_trans)
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}), migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
input); input);
auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), slc0); auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc0);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), slc1); auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc1);
auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), slc2); auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc2);
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, t2); auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, t2);
...@@ -2066,7 +2069,7 @@ TEST_CASE(reorder_slice_trans) ...@@ -2066,7 +2069,7 @@ TEST_CASE(reorder_slice_trans)
migraphx::module m2; migraphx::module m2;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}}; auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto r = m2.add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), input); auto r = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input);
auto slc0 = m2.add_instruction( auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {640}}}), r); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {640}}}), r);
...@@ -2110,9 +2113,12 @@ TEST_CASE(reorder_slice_trans_diff_perm) ...@@ -2110,9 +2113,12 @@ TEST_CASE(reorder_slice_trans_diff_perm)
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}), migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
input); input);
auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc0); auto t0 =
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc1); m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc0);
auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), slc2); auto t1 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc1);
auto t2 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), slc2);
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2);
......
...@@ -30,12 +30,12 @@ migraphx::instruction_ref add_quantize_op(migraphx::module& m, ...@@ -30,12 +30,12 @@ migraphx::instruction_ref add_quantize_op(migraphx::module& m,
migraphx::instruction_ref scale_mb; migraphx::instruction_ref scale_mb;
if(scale->get_shape().lens().front() == 1) if(scale->get_shape().lens().front() == 1)
scale_mb = scale_mb =
m.add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), scale); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), scale);
else else
scale_mb = scale_mb = m.add_instruction(
m.add_instruction(migraphx::make_op("broadcast", {{"axis", 1}, {"dims", lens}}), scale); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), scale);
auto shift_mb = auto shift_mb =
m.add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), shift); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), shift);
return m.add_instruction(migraphx::make_op(name), x, scale_mb, shift_mb); return m.add_instruction(migraphx::make_op(name), x, scale_mb, shift_mb);
} }
...@@ -48,10 +48,10 @@ migraphx::instruction_ref add_quantize_op(migraphx::module& m, ...@@ -48,10 +48,10 @@ migraphx::instruction_ref add_quantize_op(migraphx::module& m,
migraphx::instruction_ref scale_mb; migraphx::instruction_ref scale_mb;
if(scale->get_shape().lens().front() == 1) if(scale->get_shape().lens().front() == 1)
scale_mb = scale_mb =
m.add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), scale); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), scale);
else else
scale_mb = scale_mb = m.add_instruction(
m.add_instruction(migraphx::make_op("broadcast", {{"axis", 1}, {"dims", lens}}), scale); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), scale);
return m.add_instruction(migraphx::make_op(name), x, scale_mb); return m.add_instruction(migraphx::make_op(name), x, scale_mb);
} }
...@@ -402,7 +402,7 @@ TEST_CASE(conv_bias_add) ...@@ -402,7 +402,7 @@ TEST_CASE(conv_bias_add)
d5, d5,
d1); d1);
auto b1 = m1.add_instruction( auto b1 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, b1); auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, b1);
m1.add_return({a1}); m1.add_return({a1});
} }
...@@ -428,7 +428,7 @@ TEST_CASE(conv_bias_add) ...@@ -428,7 +428,7 @@ TEST_CASE(conv_bias_add)
weights); weights);
auto d6 = add_quantize_op(m2, "dequantizelinear", c1, scale1); auto d6 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
auto b1 = m2.add_instruction( auto b1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d6, b1); auto a1 = m2.add_instruction(migraphx::make_op("add"), d6, b1);
m2.add_return({a1}); m2.add_return({a1});
} }
...@@ -470,7 +470,7 @@ TEST_CASE(conv_pooling_dot) ...@@ -470,7 +470,7 @@ TEST_CASE(conv_pooling_dot)
d5, d5,
d1); d1);
auto bc1 = m1.add_instruction( auto bc1 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1); auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1);
auto ap = m1.add_instruction(migraphx::make_op("pooling", auto ap = m1.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", "average"},
...@@ -484,10 +484,10 @@ TEST_CASE(conv_pooling_dot) ...@@ -484,10 +484,10 @@ TEST_CASE(conv_pooling_dot)
auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero); auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero);
auto dot = auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4); m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4);
auto q5 = add_quantize_op(m1, "quantizelinear", dot, scale, zero); auto q5 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(m1, "dequantizelinear", q5, scale, zero); auto d9 = add_quantize_op(m1, "dequantizelinear", q5, scale, zero);
auto mb1 = m1.add_instruction( auto mb1 =
migraphx::make_op("multibroadcast", {{"output_lens", {1, 1000}}}), d3); m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = m1.add_instruction(migraphx::make_op("add"), d9, mb1); auto a2 = m1.add_instruction(migraphx::make_op("add"), d9, mb1);
m1.add_return({a2}); m1.add_return({a2});
} }
...@@ -517,7 +517,7 @@ TEST_CASE(conv_pooling_dot) ...@@ -517,7 +517,7 @@ TEST_CASE(conv_pooling_dot)
weights); weights);
auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1); auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
auto bc1 = m2.add_instruction( auto bc1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1); auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1);
auto ap = m2.add_instruction(migraphx::make_op("pooling", auto ap = m2.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", "average"},
...@@ -530,9 +530,9 @@ TEST_CASE(conv_pooling_dot) ...@@ -530,9 +530,9 @@ TEST_CASE(conv_pooling_dot)
auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero); auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero);
auto dot = auto dot =
m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), q4, db); m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), q4, db);
auto d9 = add_quantize_op(m2, "dequantizelinear", dot, scale2); auto d9 = add_quantize_op(m2, "dequantizelinear", dot, scale2);
auto mb1 = m2.add_instruction( auto mb1 =
migraphx::make_op("multibroadcast", {{"output_lens", {1, 1000}}}), d3); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1); auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1);
m2.add_return({a2}); m2.add_return({a2});
} }
...@@ -574,7 +574,7 @@ TEST_CASE(mobilenet_snippet) ...@@ -574,7 +574,7 @@ TEST_CASE(mobilenet_snippet)
d5, d5,
d1); d1);
auto bc1 = mm.add_instruction( auto bc1 = mm.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1); auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1);
auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero); auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero);
auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero); auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero);
...@@ -592,10 +592,10 @@ TEST_CASE(mobilenet_snippet) ...@@ -592,10 +592,10 @@ TEST_CASE(mobilenet_snippet)
auto d8 = add_quantize_op(mm, "dequantizelinear", q4, scale, zero); auto d8 = add_quantize_op(mm, "dequantizelinear", q4, scale, zero);
auto dot = auto dot =
mm.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4); mm.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4);
auto q5 = add_quantize_op(mm, "quantizelinear", dot, scale, zero); auto q5 = add_quantize_op(mm, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(mm, "dequantizelinear", q5, scale, zero); auto d9 = add_quantize_op(mm, "dequantizelinear", q5, scale, zero);
auto mb1 = mm.add_instruction( auto mb1 =
migraphx::make_op("multibroadcast", {{"output_lens", {1, 1000}}}), d3); mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = mm.add_instruction(migraphx::make_op("add"), d9, mb1); auto a2 = mm.add_instruction(migraphx::make_op("add"), d9, mb1);
mm.add_return({a2}); mm.add_return({a2});
......
This diff is collapsed.
...@@ -87,7 +87,7 @@ TEST_CASE(add_bcast_test) ...@@ -87,7 +87,7 @@ TEST_CASE(add_bcast_test)
auto l0 = mm->add_parameter("0", s0); auto l0 = mm->add_parameter("0", s0);
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}});
auto l2 = auto l2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", s0.lens()}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s0.lens()}}), l1);
mm->add_instruction(migraphx::make_op("add"), l0, l2); mm->add_instruction(migraphx::make_op("add"), l0, l2);
auto prog = optimize_tf("add_bcast_test.pb", false); auto prog = optimize_tf("add_bcast_test.pb", false);
...@@ -151,9 +151,9 @@ TEST_CASE(batchmatmul_test) ...@@ -151,9 +151,9 @@ TEST_CASE(batchmatmul_test)
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}});
auto trans_l0 = auto trans_l0 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l0); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l0);
auto trans_l1 = auto trans_l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1); mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1);
auto prog = optimize_tf("batchmatmul_test.pb", false); auto prog = optimize_tf("batchmatmul_test.pb", false);
...@@ -220,7 +220,7 @@ TEST_CASE(biasadd_test) ...@@ -220,7 +220,7 @@ TEST_CASE(biasadd_test)
auto l0 = mm->add_parameter("0", s0); auto l0 = mm->add_parameter("0", s0);
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
auto l2 = mm->add_instruction( auto l2 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l0->get_shape().lens()}}), l1); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l0->get_shape().lens()}}), l1);
mm->add_instruction(migraphx::make_op("add"), l0, l2); mm->add_instruction(migraphx::make_op("add"), l0, l2);
auto prog = optimize_tf("biasadd_test.pb", true); auto prog = optimize_tf("biasadd_test.pb", true);
...@@ -238,7 +238,7 @@ TEST_CASE(biasadd_scalar_test) ...@@ -238,7 +238,7 @@ TEST_CASE(biasadd_scalar_test)
auto l1 = mm->add_literal( auto l1 = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1.0}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1.0}});
auto l2 = mm->add_instruction( auto l2 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l0->get_shape().lens()}}), l1); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l0->get_shape().lens()}}), l1);
mm->add_instruction(migraphx::make_op("add"), l0, l2); mm->add_instruction(migraphx::make_op("add"), l0, l2);
auto prog = optimize_tf("biasadd_scalar_test.pb", true); auto prog = optimize_tf("biasadd_scalar_test.pb", true);
...@@ -308,7 +308,8 @@ migraphx::program create_conv() ...@@ -308,7 +308,8 @@ migraphx::program create_conv()
op.padding = {1, 1, 1, 1}; op.padding = {1, 1, 1, 1};
op.stride = {1, 1}; op.stride = {1, 1};
op.dilation = {1, 1}; op.dilation = {1, 1};
auto l2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 2, 0, 1}}}), l1); auto l2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), l1);
mm->add_instruction(op, l0, l2); mm->add_instruction(op, l0, l2);
return p; return p;
} }
...@@ -359,10 +360,10 @@ TEST_CASE(conv_relu6_test) ...@@ -359,10 +360,10 @@ TEST_CASE(conv_relu6_test)
auto l0 = std::prev(mm->end()); auto l0 = std::prev(mm->end());
auto min_val = mm->add_literal(0.0f); auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f); auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction( min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val); min_val);
max_val = mm->add_instruction( max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val); max_val);
mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
auto prog = optimize_tf("conv_relu6_test.pb", true); auto prog = optimize_tf("conv_relu6_test.pb", true);
...@@ -387,7 +388,8 @@ TEST_CASE(depthwiseconv_test) ...@@ -387,7 +388,8 @@ TEST_CASE(depthwiseconv_test)
op.stride = {1, 1}; op.stride = {1, 1};
op.dilation = {1, 1}; op.dilation = {1, 1};
op.group = 3; op.group = 3;
auto l3 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 2, 0, 1}}}), l1); auto l3 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), l1);
auto l4 = mm->add_instruction(migraphx::make_op("contiguous"), l3); auto l4 = mm->add_instruction(migraphx::make_op("contiguous"), l3);
auto l5 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1, 3, 3}}}), l4); auto l5 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1, 3, 3}}}), l4);
mm->add_instruction(op, l0, l5); mm->add_instruction(op, l0, l5);
...@@ -463,8 +465,10 @@ TEST_CASE(matmul_test) ...@@ -463,8 +465,10 @@ TEST_CASE(matmul_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}});
auto trans_l0 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l0); auto trans_l0 =
auto trans_l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l0);
auto trans_l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1); mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1);
auto prog = optimize_tf("matmul_test.pb", false); auto prog = optimize_tf("matmul_test.pb", false);
...@@ -497,7 +501,8 @@ TEST_CASE(mean_test_nhwc) ...@@ -497,7 +501,8 @@ TEST_CASE(mean_test_nhwc)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}}; migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0); auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
migraphx::op::reduce_mean op{{1, 2}}; migraphx::op::reduce_mean op{{1, 2}};
auto l2 = mm->add_instruction(op, l1); auto l2 = mm->add_instruction(op, l1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2}}}), l2); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2}}}), l2);
...@@ -595,11 +600,14 @@ TEST_CASE(pack_test_nhwc) ...@@ -595,11 +600,14 @@ TEST_CASE(pack_test_nhwc)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt0 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0); auto lt0 =
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto lt1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l1); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); auto lt1 =
auto lt2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l1);
auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l2);
std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2}; std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
std::vector<migraphx::instruction_ref> unsqueezed_args; std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t nchw_axis = 3; int64_t nchw_axis = 3;
...@@ -688,10 +696,10 @@ TEST_CASE(relu6_test) ...@@ -688,10 +696,10 @@ TEST_CASE(relu6_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_lens}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_lens});
auto min_val = mm->add_literal(0.0f); auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f); auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction( min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val); min_val);
max_val = mm->add_instruction( max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val); max_val);
mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
auto prog = optimize_tf("relu6_test.pb", false); auto prog = optimize_tf("relu6_test.pb", false);
...@@ -882,7 +890,8 @@ TEST_CASE(stridedslice_test) ...@@ -882,7 +890,8 @@ TEST_CASE(stridedslice_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0); auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
std::size_t num_axes = 4; std::size_t num_axes = 4;
migraphx::op::slice op; migraphx::op::slice op;
op.starts = {0, 0, 0, 0}; op.starts = {0, 0, 0, 0};
...@@ -917,9 +926,11 @@ TEST_CASE(stridedslice_masks_test) ...@@ -917,9 +926,11 @@ TEST_CASE(stridedslice_masks_test)
mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
std::vector<int>{1, 1, 1, 1}); std::vector<int>{1, 1, 1, 1});
auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0); auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto l2 = mm->add_instruction(op, l1); auto l2 = mm->add_instruction(op, l1);
auto l3 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 3, 1, 2}}}), l2); auto l3 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), l2);
mm->add_return({l3}); mm->add_return({l3});
auto prog = parse_tf("stridedslice_masks_test.pb", true); auto prog = parse_tf("stridedslice_masks_test.pb", true);
...@@ -961,7 +972,7 @@ TEST_CASE(transpose_test) ...@@ -961,7 +972,7 @@ TEST_CASE(transpose_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
migraphx::shape s0{migraphx::shape::int32_type, {4}}; migraphx::shape s0{migraphx::shape::int32_type, {4}};
mm->add_literal(migraphx::literal{s0, {0, 2, 3, 1}}); mm->add_literal(migraphx::literal{s0, {0, 2, 3, 1}});
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto prog = optimize_tf("transpose_test.pb", false); auto prog = optimize_tf("transpose_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
......
...@@ -14,12 +14,12 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1> ...@@ -14,12 +14,12 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = auto tl1 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = auto tl2 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3); migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3);
......
...@@ -13,12 +13,12 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4> ...@@ -13,12 +13,12 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 4, 6, 3}}; migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 4, 6, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 2, 6, 3}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 2, 6, 3}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl1 = auto tl1 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 0, 1, 2}}}), l1); migraphx::make_op("transpose", {{"permutation", {3, 0, 1, 2}}}), l1);
auto tl2 = auto tl2 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 1, 2, 0}}}), l2); migraphx::make_op("transpose", {{"permutation", {3, 1, 2, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, tl2); mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, tl2);
return p; return p;
} }
......
...@@ -13,13 +13,13 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5> ...@@ -13,13 +13,13 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 7, 2}}; migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 7, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 5, 7}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 5, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl1 = auto tl1 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
auto sl1 = mm->add_instruction(migraphx::make_op("add"), tl1, tl1); auto sl1 = mm->add_instruction(migraphx::make_op("add"), tl1, tl1);
auto tl2 = auto tl2 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto sl2 = mm->add_instruction(migraphx::make_op("add"), tl2, tl2); auto sl2 = mm->add_instruction(migraphx::make_op("add"), tl2, tl2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}}), sl1, sl2); mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}}), sl1, sl2);
return p; return p;
......
...@@ -16,7 +16,7 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv> ...@@ -16,7 +16,7 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2); auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2);
auto bul2 = mm->add_instruction( auto bul2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 5, 1}}}), ul2); migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 5, 1}}}), ul2);
mm->add_instruction(migraphx::make_op("dot"), l1, bul2); mm->add_instruction(migraphx::make_op("dot"), l1, bul2);
......
...@@ -12,10 +12,10 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1> ...@@ -12,10 +12,10 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto bl2 = mm->add_instruction( auto bl2 =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4}}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4}}}), l2);
mm->add_instruction(migraphx::make_op("dot"), l1, bl2); mm->add_instruction(migraphx::make_op("dot"), l1, bl2);
......
...@@ -12,10 +12,10 @@ struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2> ...@@ -12,10 +12,10 @@ struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto bl2 = mm->add_instruction( auto bl2 =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4}}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4}}}), l2);
mm->add_instruction(migraphx::make_op("dot"), l1, bl2); mm->add_instruction(migraphx::make_op("dot"), l1, bl2);
......
...@@ -12,9 +12,9 @@ struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3> ...@@ -12,9 +12,9 @@ struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction( auto bl1 =
migraphx::make_op("multibroadcast", {{"output_lens", {3, 2, 3}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 3}}}), l1);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
mm->add_instruction(migraphx::make_op("dot"), bl1, l2); mm->add_instruction(migraphx::make_op("dot"), bl1, l2);
......
...@@ -12,9 +12,9 @@ struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4> ...@@ -12,9 +12,9 @@ struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction( auto bl1 =
migraphx::make_op("multibroadcast", {{"output_lens", {3, 2, 3}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 3}}}), l1);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
mm->add_instruction(migraphx::make_op("dot"), bl1, l2); mm->add_instruction(migraphx::make_op("dot"), bl1, l2);
......
...@@ -14,7 +14,7 @@ struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5> ...@@ -14,7 +14,7 @@ struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5>
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction( auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 2, 3}}}), l1); migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 2, 3}}}), l1);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
mm->add_instruction(migraphx::make_op("dot"), bl1, l2); mm->add_instruction(migraphx::make_op("dot"), bl1, l2);
......
...@@ -14,10 +14,10 @@ struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6> ...@@ -14,10 +14,10 @@ struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6>
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction( auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 2, 3}}}), l1); migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 2, 3}}}), l1);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto bl2 = mm->add_instruction( auto bl2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 3, 4}}}), l2); migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 3, 4}}}), l2);
mm->add_instruction(migraphx::make_op("dot"), bl1, bl2); mm->add_instruction(migraphx::make_op("dot"), bl1, bl2);
......
...@@ -14,7 +14,7 @@ struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7> ...@@ -14,7 +14,7 @@ struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7>
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction( auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 2, 3}}}), l1); migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 2, 3}}}), l1);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
mm->add_instruction(migraphx::make_op("dot"), bl1, l2); mm->add_instruction(migraphx::make_op("dot"), bl1, l2);
......
...@@ -15,7 +15,7 @@ struct gemm_2args_vbm : verify_program<gemm_2args_vbm> ...@@ -15,7 +15,7 @@ struct gemm_2args_vbm : verify_program<gemm_2args_vbm>
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto ul1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l1); auto ul1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l1);
auto bul1 = mm->add_instruction( auto bul1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 1, 5}}}), ul1); migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 1, 5}}}), ul1);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
......
...@@ -12,9 +12,10 @@ struct gemm_multi_transpose : verify_program<gemm_multi_transpose> ...@@ -12,9 +12,10 @@ struct gemm_multi_transpose : verify_program<gemm_multi_transpose>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 2, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 2, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), l2); auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), l2);
float alpha = 1.0f; float alpha = 1.0f;
float beta = 1.0f; float beta = 1.0f;
......
...@@ -14,10 +14,11 @@ struct quant_dot_3args_2 : verify_program<quant_dot_3args_2> ...@@ -14,10 +14,11 @@ struct quant_dot_3args_2 : verify_program<quant_dot_3args_2>
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); auto tl1 =
auto l2 = mm->add_parameter("b", m2_shape); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l3 = mm->add_parameter("c", m3_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, l2, l3); migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, l2, l3);
return p; return p;
......
...@@ -14,10 +14,11 @@ struct quant_dot_3args_3 : verify_program<quant_dot_3args_3> ...@@ -14,10 +14,11 @@ struct quant_dot_3args_3 : verify_program<quant_dot_3args_3>
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2); auto tl2 =
auto l3 = mm->add_parameter("c", m3_shape); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), l1, tl2, l3); migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), l1, tl2, l3);
return p; return p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment