Commit 7e297b13 authored by Paul's avatar Paul
Browse files

Merge

parents 86ea5e91 aa7ff911
...@@ -48,7 +48,7 @@ TEST_CASE(as_json) ...@@ -48,7 +48,7 @@ TEST_CASE(as_json)
TEST_CASE(as_file) TEST_CASE(as_file)
{ {
std::string filename = "migraphx_program.dat"; std::string filename = "migraphx_program.mxr";
migraphx::program p1 = create_program(); migraphx::program p1 = create_program();
migraphx::save(p1, filename); migraphx::save(p1, filename);
migraphx::program p2 = migraphx::load(filename); migraphx::program p2 = migraphx::load(filename);
......
...@@ -608,4 +608,15 @@ TEST_CASE(cpp_type_name) ...@@ -608,4 +608,15 @@ TEST_CASE(cpp_type_name)
EXPECT(test::throws([&] { migraphx::shape::cpp_type(migraphx::shape::tuple_type); })); EXPECT(test::throws([&] { migraphx::shape::cpp_type(migraphx::shape::tuple_type); }));
} }
TEST_CASE(test_with_type)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 0}};
EXPECT(s.type() == migraphx::shape::float_type);
auto new_s = s.with_type(migraphx::shape::half_type);
EXPECT(s.type() == migraphx::shape::float_type);
EXPECT(s.type() != new_s.type());
EXPECT(s.lens() == new_s.lens());
EXPECT(s.strides() == new_s.strides());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -204,7 +204,7 @@ TEST_CASE(simplify_mul_conv1) ...@@ -204,7 +204,7 @@ TEST_CASE(simplify_mul_conv1)
w); w);
auto a = m.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}})); auto a = m.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}}));
auto b = m.add_instruction( auto b = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 256, 14, 14}}}), a); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 256, 14, 14}}}), a);
auto mul = m.add_instruction(migraphx::make_op("mul"), conv, b); auto mul = m.add_instruction(migraphx::make_op("mul"), conv, b);
m.add_instruction(pass_op{}, mul); m.add_instruction(pass_op{}, mul);
EXPECT(conv->outputs().front()->name() == "mul"); EXPECT(conv->outputs().front()->name() == "mul");
...@@ -226,7 +226,7 @@ TEST_CASE(simplify_mul_slice_conv1) ...@@ -226,7 +226,7 @@ TEST_CASE(simplify_mul_slice_conv1)
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}})); auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = m1.add_instruction( auto b = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a);
auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b); auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b);
auto slice2 = m1.add_instruction( auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv);
...@@ -244,7 +244,7 @@ TEST_CASE(simplify_mul_slice_conv1) ...@@ -244,7 +244,7 @@ TEST_CASE(simplify_mul_slice_conv1)
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {384}}}), w); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {384}}}), w);
auto a = m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}})); auto a = m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = m2.add_instruction( auto b = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", {384, 1024, 1, 1}}}), a); migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {384, 1024, 1, 1}}}), a);
auto mul = m2.add_instruction(migraphx::make_op("mul"), b, wslice1); auto mul = m2.add_instruction(migraphx::make_op("mul"), b, wslice1);
auto wslice2 = m2.add_instruction( auto wslice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w);
...@@ -272,7 +272,7 @@ TEST_CASE(simplify_mul_slice_conv_overlapping_slice) ...@@ -272,7 +272,7 @@ TEST_CASE(simplify_mul_slice_conv_overlapping_slice)
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}})); auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = m1.add_instruction( auto b = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a);
auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b); auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b);
auto slice2 = m1.add_instruction( auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {383}}, {"ends", {767}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {383}}, {"ends", {767}}}), conv);
...@@ -296,7 +296,7 @@ TEST_CASE(simplify_mul_slice_conv_not_all_slice) ...@@ -296,7 +296,7 @@ TEST_CASE(simplify_mul_slice_conv_not_all_slice)
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}})); auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = m1.add_instruction( auto b = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a);
auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b); auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b);
auto c = m1.add_literal( auto c = m1.add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {1, 768, 17, 17}})); migraphx::generate_literal({migraphx::shape::int32_type, {1, 768, 17, 17}}));
...@@ -1086,10 +1086,10 @@ TEST_CASE(simplify_slice_different_axis) ...@@ -1086,10 +1086,10 @@ TEST_CASE(simplify_slice_different_axis)
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {1}}}), input);
auto one = m1.add_literal(1); auto one = m1.add_literal(1);
auto oneb = m1.add_instruction( auto oneb = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {3, 1, 4, 2}}}), one); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4, 2}}}), one);
auto two = m1.add_literal(2); auto two = m1.add_literal(2);
auto twob = m1.add_instruction( auto twob = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 3}, {"dims", {3, 2, 4, 1}}}), two); migraphx::make_op("broadcast", {{"axis", 3}, {"out_lens", {3, 2, 4, 1}}}), two);
auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb); auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1); auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1);
auto reshape1 = m1.add_instruction(r, relu1); auto reshape1 = m1.add_instruction(r, relu1);
...@@ -1709,17 +1709,17 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion) ...@@ -1709,17 +1709,17 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
auto a1 = auto a1 =
m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1)); m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
auto b1 = m1.add_instruction( auto b1 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a1); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a1);
auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b1); auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b1);
auto a2 = auto a2 =
m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2)); m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2));
auto b2 = m1.add_instruction( auto b2 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a2);
auto add1 = m1.add_instruction(migraphx::make_op("add"), mul, b2); auto add1 = m1.add_instruction(migraphx::make_op("add"), mul, b2);
auto a3 = auto a3 =
m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3)); m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
auto b3 = m1.add_instruction( auto b3 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a3); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a3);
auto slice2 = m1.add_instruction( auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv);
auto add2 = m1.add_instruction(migraphx::make_op("add"), slice2, b3); auto add2 = m1.add_instruction(migraphx::make_op("add"), slice2, b3);
...@@ -1737,7 +1737,7 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion) ...@@ -1737,7 +1737,7 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
auto a1 = auto a1 =
m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1)); m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
auto b1 = m2.add_instruction( auto b1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", {384, 1024, 1, 1}}}), a1); migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {384, 1024, 1, 1}}}), a1);
auto mul = m2.add_instruction(migraphx::make_op("mul"), b1, wslice1); auto mul = m2.add_instruction(migraphx::make_op("mul"), b1, wslice1);
auto wslice2 = m2.add_instruction( auto wslice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w);
...@@ -1749,7 +1749,7 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion) ...@@ -1749,7 +1749,7 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3)); m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
auto concat2 = m2.add_instruction(migraphx::make_op("concat"), a2, a3); auto concat2 = m2.add_instruction(migraphx::make_op("concat"), a2, a3);
auto b4 = m2.add_instruction( auto b4 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 768, 17, 17}}}), concat2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 768, 17, 17}}}), concat2);
auto add = m2.add_instruction(migraphx::make_op("add"), conv, b4); auto add = m2.add_instruction(migraphx::make_op("add"), conv, b4);
auto slice1 = m2.add_instruction( auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), add); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), add);
...@@ -1785,9 +1785,9 @@ TEST_CASE(reorder_reshape_slice) ...@@ -1785,9 +1785,9 @@ TEST_CASE(reorder_reshape_slice)
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1); auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2); auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r0); auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), r0);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r1); auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), r1);
auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), r2); auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), r2);
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2);
...@@ -1810,9 +1810,12 @@ TEST_CASE(reorder_reshape_slice) ...@@ -1810,9 +1810,12 @@ TEST_CASE(reorder_reshape_slice)
auto slc2 = m2.add_instruction( auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {20}}, {"ends", {30}}}), r); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {20}}, {"ends", {30}}}), r);
auto t0 = m2.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc0); auto t0 =
auto t1 = m2.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc1); m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc0);
auto t2 = m2.add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), slc2); auto t1 =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc1);
auto t2 =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), slc2);
auto sum = m2.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m2.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m2.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m2.add_instruction(migraphx::make_op("dot"), sum, t2);
...@@ -1857,9 +1860,9 @@ TEST_CASE(reorder_reshape_slice_move_axis1) ...@@ -1857,9 +1860,9 @@ TEST_CASE(reorder_reshape_slice_move_axis1)
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1); auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2); auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r0); auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), r0);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r1); auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), r1);
auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), r2); auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), r2);
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2);
...@@ -1878,13 +1881,13 @@ TEST_CASE(reorder_reshape_slice_move_axis1) ...@@ -1878,13 +1881,13 @@ TEST_CASE(reorder_reshape_slice_move_axis1)
auto rsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input); auto rsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = m.add_instruction( auto slc0 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp);
auto t0 = m.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc0); auto t0 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc0);
auto slc1 = m.add_instruction( auto slc1 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp);
auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc1); auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc1);
auto slc2 = m.add_instruction( auto slc2 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp);
auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), slc2); auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), slc2);
auto sum = m.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m.add_instruction(migraphx::make_op("dot"), sum, t2);
...@@ -2051,9 +2054,9 @@ TEST_CASE(reorder_slice_trans) ...@@ -2051,9 +2054,9 @@ TEST_CASE(reorder_slice_trans)
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}), migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
input); input);
auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), slc0); auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc0);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), slc1); auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc1);
auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), slc2); auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc2);
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, t2); auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, t2);
...@@ -2066,7 +2069,7 @@ TEST_CASE(reorder_slice_trans) ...@@ -2066,7 +2069,7 @@ TEST_CASE(reorder_slice_trans)
migraphx::module m2; migraphx::module m2;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}}; auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto r = m2.add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), input); auto r = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input);
auto slc0 = m2.add_instruction( auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {640}}}), r); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {640}}}), r);
...@@ -2110,9 +2113,12 @@ TEST_CASE(reorder_slice_trans_diff_perm) ...@@ -2110,9 +2113,12 @@ TEST_CASE(reorder_slice_trans_diff_perm)
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}), migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
input); input);
auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc0); auto t0 =
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc1); m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc0);
auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), slc2); auto t1 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc1);
auto t2 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), slc2);
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2);
......
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/instruction.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/apply_alpha_beta.hpp>
bool is_convolution(const migraphx::instruction& ins) { return ins.name() == "convolution"; }
bool is_dot(const migraphx::instruction& ins) { return ins.name() == "dot"; }
void run_pass(migraphx::module& m)
{
migraphx::simplify_qdq sqdq;
sqdq.apply(m);
}
migraphx::instruction_ref add_quantize_op(migraphx::module& m,
const std::string& name,
migraphx::instruction_ref x,
migraphx::instruction_ref scale,
migraphx::instruction_ref shift)
{
auto lens = x->get_shape().lens();
migraphx::instruction_ref scale_mb;
if(scale->get_shape().lens().front() == 1)
scale_mb =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), scale);
else
scale_mb = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), scale);
auto shift_mb =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), shift);
return m.add_instruction(migraphx::make_op(name), x, scale_mb, shift_mb);
}
migraphx::instruction_ref add_quantize_op(migraphx::module& m,
const std::string& name,
migraphx::instruction_ref x,
migraphx::instruction_ref scale)
{
auto lens = x->get_shape().lens();
migraphx::instruction_ref scale_mb;
if(scale->get_shape().lens().front() == 1)
scale_mb =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), scale);
else
scale_mb = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), scale);
return m.add_instruction(migraphx::make_op(name), x, scale_mb);
}
TEST_CASE(remove_qdq)
{
migraphx::shape sh1{migraphx::shape::float_type, {100, 100}};
migraphx::shape sh2{migraphx::shape::float_type, {100, 100}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto add = m1.add_instruction(migraphx::make_op("add"), d1, d2);
m1.add_return({add});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto add = m2.add_instruction(migraphx::make_op("add"), t1, t2);
m2.add_return({add});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(qdq_different_scales)
{
migraphx::shape sh1{migraphx::shape::float_type, {100, 100}};
migraphx::shape sh2{migraphx::shape::float_type, {100, 100}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale1 = m1.add_literal(0.5f);
auto scale2 = m1.add_literal(0.4f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale2, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale1, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero);
auto add = m1.add_instruction(migraphx::make_op("add"), d1, d2);
m1.add_return({add});
}
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(dot)
{
migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1);
m2.add_return({d3});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(dot_non_zero_point)
{
migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{1});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto dot = m2.add_instruction(migraphx::make_op("dot"), t1, t2);
m2.add_return({dot});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(dot_uint8)
{
migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::uint8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto dot = m2.add_instruction(migraphx::make_op("dot"), t1, t2);
m2.add_return({dot});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(dot_add)
{
migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}};
migraphx::shape sh3{migraphx::shape::float_type, {1280, 1024}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto ab = m1.add_parameter("ab", sh3);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
auto q3 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero);
auto add = m1.add_instruction(migraphx::make_op("add"), d3, ab);
m1.add_return({add});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto ab = m2.add_parameter("ab", sh3);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1);
auto add = m2.add_instruction(migraphx::make_op("add"), d3, ab);
m2.add_return({add});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(conv)
{
migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}};
migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}};
migraphx::module m1;
{
auto input = m1.add_parameter("input", s7);
auto weights = m1.add_parameter("weights", s4);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto d1 = add_quantize_op(m1, "dequantizelinear", weights, scale, zero);
auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero);
auto d5 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto c1 = m1.add_instruction(migraphx::make_op("convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
d5,
d1);
m1.add_return({c1});
}
migraphx::module m2;
{
auto input = m2.add_parameter("input", s7);
auto weights = m2.add_parameter("weights", s4);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
q1,
weights);
auto d6 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
m2.add_return({d6});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(conv_multi_scale)
{
migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}};
migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}};
migraphx::shape s8{migraphx::shape::float_type, {320}};
migraphx::module m1;
{
auto input = m1.add_parameter("input", s7);
auto weights = m1.add_parameter("weights", s4);
auto scale = m1.add_literal(migraphx::generate_literal(s8, 0));
auto zero = m1.add_literal(std::int8_t{0});
auto d1 = add_quantize_op(m1, "dequantizelinear", weights, scale, zero);
auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero);
auto d5 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto c1 = m1.add_instruction(migraphx::make_op("convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
d5,
d1);
m1.add_return({c1});
}
migraphx::module m2;
{
auto input = m2.add_parameter("input", s7);
auto weights = m2.add_parameter("weights", s4);
auto scale = m2.add_literal(migraphx::generate_literal(s8, 0));
auto zero = m2.add_literal(std::int8_t{0});
auto d1 = add_quantize_op(m2, "dequantizelinear", weights, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
input,
d1);
m2.add_return({c1});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(conv_bias_add)
{
migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}};
migraphx::shape s6{migraphx::shape::int32_type, {1280}};
migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}};
migraphx::module m1;
{
auto input = m1.add_parameter("input", s7);
auto weights = m1.add_parameter("weights", s4);
auto bias = m1.add_parameter("bias", s6);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto d1 = add_quantize_op(m1, "dequantizelinear", weights, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero);
auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero);
auto d5 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto c1 = m1.add_instruction(migraphx::make_op("convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
d5,
d1);
auto b1 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, b1);
m1.add_return({a1});
}
migraphx::module m2;
{
auto input = m2.add_parameter("input", s7);
auto weights = m2.add_parameter("weights", s4);
auto bias = m2.add_parameter("bias", s6);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f);
auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
q1,
weights);
auto d6 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
auto b1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d6, b1);
m2.add_return({a1});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(conv_pooling_dot)
{
migraphx::shape s2{migraphx::shape::int8_type, {1280, 1000}};
migraphx::shape s3{migraphx::shape::int8_type, {1000}};
migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}};
migraphx::shape s6{migraphx::shape::int32_type, {1280}};
migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}};
migraphx::module m1;
{
auto db = m1.add_parameter("db", s2); // dot input b
auto ab = m1.add_parameter("ab", s3); // add input b
auto weights = m1.add_parameter("weights", s4);
auto bias = m1.add_parameter("bias", s6);
auto input = m1.add_parameter("input", s7);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto d1 = add_quantize_op(m1, "dequantizelinear", weights, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero);
auto d3 = add_quantize_op(m1, "dequantizelinear", ab, scale, zero);
auto d4 = add_quantize_op(m1, "dequantizelinear", db, scale, zero);
auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero);
auto d5 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto c1 = m1.add_instruction(migraphx::make_op("convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
d5,
d1);
auto bc1 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1);
auto ap =
m1.add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {7, 7}},
{"ceil_mode", 0}}),
a1);
auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m1, "quantizelinear", fl, scale, zero);
auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d8, d4);
auto q5 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(m1, "dequantizelinear", q5, scale, zero);
auto mb1 =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = m1.add_instruction(migraphx::make_op("add"), d9, mb1);
m1.add_return({a2});
}
migraphx::module m2;
{
auto db = m2.add_parameter("db", s2); // dot input b
auto ab = m2.add_parameter("ab", s3); // add input b
auto weights = m2.add_parameter("weights", s4);
auto bias = m2.add_parameter("bias", s6);
auto input = m2.add_parameter("input", s7);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f);
auto scale2 = m2.add_literal(0.25f);
auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero);
auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
q1,
weights);
auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
auto bc1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1);
auto ap =
m2.add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {7, 7}},
{"ceil_mode", 0}}),
a1);
auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db);
auto d9 = add_quantize_op(m2, "dequantizelinear", dot, scale2);
auto mb1 =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1);
m2.add_return({a2});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(mobilenet_snippet)
{
migraphx::shape s2{migraphx::shape::int8_type, {1280, 1000}};
migraphx::shape s3{migraphx::shape::int8_type, {1000}};
migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}};
migraphx::shape s6{migraphx::shape::int32_type, {1280}};
migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}};
auto create_module = [&]() {
migraphx::module mm;
auto db = mm.add_parameter("db", s2); // dot input b
auto ab = mm.add_parameter("ab", s3); // add input b
auto weights = mm.add_parameter("weights", s4);
auto bias = mm.add_parameter("bias", s6);
auto input = mm.add_parameter("input", s7);
auto scale = mm.add_literal(0.5f);
auto zero = mm.add_literal(std::int8_t{0});
auto d1 = add_quantize_op(mm, "dequantizelinear", weights, scale, zero);
auto d2 = add_quantize_op(mm, "dequantizelinear", bias, scale, zero);
auto d3 = add_quantize_op(mm, "dequantizelinear", ab, scale, zero);
auto d4 = add_quantize_op(mm, "dequantizelinear", db, scale, zero);
auto q1 = add_quantize_op(mm, "quantizelinear", input, scale, zero);
auto d5 = add_quantize_op(mm, "dequantizelinear", q1, scale, zero);
auto c1 = mm.add_instruction(migraphx::make_op("convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
d5,
d1);
auto bc1 = mm.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1);
auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero);
auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero);
auto ap =
mm.add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {7, 7}},
{"ceil_mode", 0}}),
d6);
auto q3 = add_quantize_op(mm, "quantizelinear", ap, scale, zero);
auto d7 = add_quantize_op(mm, "dequantizelinear", q3, scale, zero);
auto rs = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1}}}), d7);
auto q4 = add_quantize_op(mm, "quantizelinear", rs, scale, zero);
auto d8 = add_quantize_op(mm, "dequantizelinear", q4, scale, zero);
auto dot = mm.add_instruction(migraphx::make_op("dot"), d8, d4);
auto q5 = add_quantize_op(mm, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(mm, "dequantizelinear", q5, scale, zero);
auto mb1 =
mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = mm.add_instruction(migraphx::make_op("add"), d9, mb1);
mm.add_return({a2});
return mm;
};
auto mod1 = create_module();
auto mod2 = create_module();
run_pass(mod2);
auto match_qdq = migraphx::match::name("dequantizelinear")(
migraphx::match::arg(0)(migraphx::match::name("quantizelinear")));
auto ins1 = migraphx::match::find_match(mod1, match_qdq);
auto ins2 = migraphx::match::find_match(mod2, match_qdq);
EXPECT((ins1.result != mod1.end()) and (ins2.result == mod2.end()));
EXPECT(any_of(mod1, &is_convolution));
EXPECT(none_of(mod2, &is_convolution));
EXPECT(any_of(mod1, &is_dot));
EXPECT(none_of(mod2, &is_dot));
}
TEST_CASE(conv_correctness)
{
migraphx::shape si{migraphx::shape::float_type, {2, 3, 4, 4}};
migraphx::shape sw{migraphx::shape::int8_type, {2, 3, 3, 3}};
migraphx::program p1;
{
auto* m1 = p1.get_main_module();
auto input = m1->add_parameter("input", si);
auto weights = m1->add_parameter("weights", sw);
auto scale_i = m1->add_literal(0.5f);
auto scale_w = m1->add_literal(0.1f);
auto zero = m1->add_literal(std::int8_t{0});
auto d1 = add_quantize_op(*m1, "dequantizelinear", weights, scale_w, zero);
auto q1 = add_quantize_op(*m1, "quantizelinear", input, scale_i, zero);
auto d5 = add_quantize_op(*m1, "dequantizelinear", q1, scale_i, zero);
auto c1 = m1->add_instruction(migraphx::make_op("convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
d5,
d1);
m1->add_return({c1});
run_pass(*m1);
}
migraphx::program p2;
{
auto* m2 = p2.get_main_module();
auto input = m2->add_parameter("input", si);
auto weights = m2->add_parameter("weights", sw);
auto scale = m2->add_literal(0.1f);
auto zero = m2->add_literal(std::int8_t{0});
auto d1 = add_quantize_op(*m2, "dequantizelinear", weights, scale, zero);
auto c1 = m2->add_instruction(migraphx::make_op("convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
input,
d1);
m2->add_return({c1});
}
std::vector<float> iv(si.elements(), 4);
auto input = migraphx::argument(si, iv.data());
std::vector<float> wv(sw.elements(), 10);
auto weights = migraphx::argument(sw, wv.data());
p1.compile(migraphx::target(migraphx::ref::target{}));
p2.compile(migraphx::target(migraphx::ref::target{}));
auto result1 = p1.eval({{"input", input}, {"weights", weights}}).back();
std::vector<float> rv1(16);
result1.visit([&](auto output) { rv1.assign(output.begin(), output.end()); });
auto result2 = p2.eval({{"input", input}, {"weights", weights}}).back();
std::vector<float> rv2(16);
result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(rv1, rv2));
}
TEST_CASE(dot_correctness)
{
migraphx::shape sh1{migraphx::shape::float_type, {10, 4}};
migraphx::shape sh2{migraphx::shape::float_type, {4, 12}};
migraphx::shape sh3{migraphx::shape::float_type, {10, 12}};
migraphx::program p1;
{
auto* m1 = p1.get_main_module();
auto a = m1->add_parameter("a", sh1);
auto b = m1->add_parameter("b", sh2);
auto scale_a = m1->add_literal(0.4f);
auto scale_b = m1->add_literal(0.5f);
auto zero = m1->add_literal(std::int8_t{0});
auto q1 = add_quantize_op(*m1, "quantizelinear", a, scale_a, zero);
auto d1 = add_quantize_op(*m1, "dequantizelinear", q1, scale_a, zero);
auto q2 = add_quantize_op(*m1, "quantizelinear", b, scale_b, zero);
auto d2 = add_quantize_op(*m1, "dequantizelinear", q2, scale_b, zero);
auto dot = m1->add_instruction(migraphx::make_op("dot"), d1, d2);
m1->add_return({dot});
run_pass(*m1);
}
migraphx::program p2;
{
auto* m2 = p2.get_main_module();
auto a = m2->add_parameter("a", sh1);
auto b = m2->add_parameter("b", sh2);
auto dot = m2->add_instruction(migraphx::make_op("dot"), a, b);
m2->add_return({dot});
}
std::vector<float> av(sh1.elements(), 10);
auto a = migraphx::argument(sh1, av.data());
std::vector<float> bv(sh2.elements(), 10);
auto b = migraphx::argument(sh2, bv.data());
p1.compile(migraphx::target(migraphx::ref::target{}));
p2.compile(migraphx::target(migraphx::ref::target{}));
auto result1 = p1.eval({{"a", a}, {"b", b}}).back();
std::vector<float> rv1(sh3.elements());
result1.visit([&](auto output) { rv1.assign(output.begin(), output.end()); });
auto result2 = p2.eval({{"a", a}, {"b", b}}).back();
std::vector<float> rv2(sh3.elements());
result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(rv1, rv2));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -22,7 +22,7 @@ TEST_CASE(double_contig) ...@@ -22,7 +22,7 @@ TEST_CASE(double_contig)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), t1); auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), t1);
auto c2 = mm->add_instruction(migraphx::make_op("contiguous"), c1); auto c2 = mm->add_instruction(migraphx::make_op("contiguous"), c1);
mm->add_return({c2}); mm->add_return({c2});
...@@ -42,8 +42,8 @@ TEST_CASE(double_transpose) ...@@ -42,8 +42,8 @@ TEST_CASE(double_transpose)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t1); auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t1);
mm->add_return({t2}); mm->add_return({t2});
EXPECT(mm->get_output_shapes().back().standard()); EXPECT(mm->get_output_shapes().back().standard());
EXPECT(not mm->get_output_shapes().back().transposed()); EXPECT(not mm->get_output_shapes().back().transposed());
...@@ -61,9 +61,9 @@ TEST_CASE(double_transpose_contig) ...@@ -61,9 +61,9 @@ TEST_CASE(double_transpose_contig)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), t1); auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), t1);
auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), c1); auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), c1);
auto c2 = mm->add_instruction(migraphx::make_op("contiguous"), t2); auto c2 = mm->add_instruction(migraphx::make_op("contiguous"), t2);
mm->add_return({c2}); mm->add_return({c2});
EXPECT(mm->get_output_shapes().back().standard()); EXPECT(mm->get_output_shapes().back().standard());
...@@ -82,7 +82,7 @@ TEST_CASE(single_transpose) ...@@ -82,7 +82,7 @@ TEST_CASE(single_transpose)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
mm->add_return({t1}); mm->add_return({t1});
EXPECT(not mm->get_output_shapes().back().standard()); EXPECT(not mm->get_output_shapes().back().standard());
EXPECT(mm->get_output_shapes().back().transposed()); EXPECT(mm->get_output_shapes().back().transposed());
...@@ -100,8 +100,8 @@ TEST_CASE(double_transpose_sin_pass) ...@@ -100,8 +100,8 @@ TEST_CASE(double_transpose_sin_pass)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t1); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t1);
EXPECT(mm->get_output_shapes().back().standard()); EXPECT(mm->get_output_shapes().back().standard());
EXPECT(not mm->get_output_shapes().back().transposed()); EXPECT(not mm->get_output_shapes().back().transposed());
run_pass(*mm); run_pass(*mm);
...@@ -119,7 +119,7 @@ TEST_CASE(single_transpose_sin_pass) ...@@ -119,7 +119,7 @@ TEST_CASE(single_transpose_sin_pass)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2()); auto l = mm->add_literal(get_2x2());
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
EXPECT(not mm->get_output_shapes().back().standard()); EXPECT(not mm->get_output_shapes().back().standard());
EXPECT(mm->get_output_shapes().back().transposed()); EXPECT(mm->get_output_shapes().back().transposed());
run_pass(*mm); run_pass(*mm);
...@@ -137,7 +137,8 @@ TEST_CASE(reshape_transpose) ...@@ -137,7 +137,8 @@ TEST_CASE(reshape_transpose)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto r1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 4, 28, 56, 56}}}), x); auto r1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 4, 28, 56, 56}}}), x);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 1, 3, 4}}}), r1); auto t =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3, 4}}}), r1);
auto ct = m.add_instruction(migraphx::make_op("contiguous"), t); auto ct = m.add_instruction(migraphx::make_op("contiguous"), t);
auto r2 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 112, 56, 56}}}), ct); auto r2 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 112, 56, 56}}}), ct);
m.add_return({r2}); m.add_return({r2});
...@@ -154,7 +155,7 @@ TEST_CASE(transpose_contiguous) ...@@ -154,7 +155,7 @@ TEST_CASE(transpose_contiguous)
auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}}; auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), x); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x);
auto c1 = m.add_instruction(migraphx::make_op("contiguous"), t); auto c1 = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_return({c1}); m.add_return({c1});
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
...@@ -170,7 +171,7 @@ TEST_CASE(transpose_double_contiguous) ...@@ -170,7 +171,7 @@ TEST_CASE(transpose_double_contiguous)
auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}}; auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), x); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x);
auto c1 = m.add_instruction(migraphx::make_op("contiguous"), t); auto c1 = m.add_instruction(migraphx::make_op("contiguous"), t);
auto c2 = m.add_instruction(migraphx::make_op("contiguous"), c1); auto c2 = m.add_instruction(migraphx::make_op("contiguous"), c1);
m.add_return({c2}); m.add_return({c2});
...@@ -188,8 +189,8 @@ TEST_CASE(transpose_partial1) ...@@ -188,8 +189,8 @@ TEST_CASE(transpose_partial1)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), x); auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), x);
auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), t1); auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), t1);
m.add_return({t2}); m.add_return({t2});
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
...@@ -204,9 +205,9 @@ TEST_CASE(transpose_partial2) ...@@ -204,9 +205,9 @@ TEST_CASE(transpose_partial2)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), x); auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), x);
auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), t1); auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), t1);
auto t3 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), t2); auto t3 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), t2);
m.add_return({t3}); m.add_return({t3});
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
...@@ -221,10 +222,10 @@ TEST_CASE(transpose_partial3) ...@@ -221,10 +222,10 @@ TEST_CASE(transpose_partial3)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), x); auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), x);
auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), t1); auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), t1);
auto t3 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), t2); auto t3 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), t2);
auto t4 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), t3); auto t4 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), t3);
m.add_return({t4}); m.add_return({t4});
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
...@@ -239,7 +240,7 @@ TEST_CASE(nop_transpose1) ...@@ -239,7 +240,7 @@ TEST_CASE(nop_transpose1)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), x); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), x);
m.add_return({t}); m.add_return({t});
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
...@@ -254,10 +255,10 @@ TEST_CASE(nop_transpose2) ...@@ -254,10 +255,10 @@ TEST_CASE(nop_transpose2)
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), x); auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), x);
auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), t1); auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), t1);
auto t3 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), t2); auto t3 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), t2);
auto t4 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), t3); auto t4 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), t3);
m.add_instruction(pass_op{}, t4); m.add_instruction(pass_op{}, t4);
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
...@@ -274,8 +275,10 @@ TEST_CASE(nop_transpose3) ...@@ -274,8 +275,10 @@ TEST_CASE(nop_transpose3)
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto y = m.add_parameter("y", s); auto y = m.add_parameter("y", s);
auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), x, y); auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), x, y);
auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2, 3}}}), concat); auto t1 =
auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t1); m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3}}}), concat);
auto t2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t1);
m.add_return({t2}); m.add_return({t2});
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
...@@ -306,13 +309,14 @@ TEST_CASE(concat_transpose1) ...@@ -306,13 +309,14 @@ TEST_CASE(concat_transpose1)
{ {
migraphx::module m; migraphx::module m;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto y = m.add_parameter("y", s); auto y = m.add_parameter("y", s);
auto xt = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), x); auto xt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), x);
auto yt = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), y); auto yt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), y);
auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), xt, yt); auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), xt, yt);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), concat); auto t =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), concat);
m.add_return({t}); m.add_return({t});
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
...@@ -329,13 +333,14 @@ TEST_CASE(concat_transpose2) ...@@ -329,13 +333,14 @@ TEST_CASE(concat_transpose2)
{ {
migraphx::module m; migraphx::module m;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto y = m.add_parameter("y", s); auto y = m.add_parameter("y", s);
auto xt = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), x); auto xt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto yt = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), y); auto yt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), y);
auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", -1}}), xt, yt); auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", -1}}), xt, yt);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat); auto t =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), concat);
m.add_return({t}); m.add_return({t});
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
...@@ -352,13 +357,14 @@ TEST_CASE(concat_transpose3) ...@@ -352,13 +357,14 @@ TEST_CASE(concat_transpose3)
{ {
migraphx::module m; migraphx::module m;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}); auto x = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}});
auto y = m.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 4}}); auto y = m.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 4}});
auto xt = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), x); auto xt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto yt = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), y); auto yt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), y);
auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), xt, yt); auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), xt, yt);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat); auto t =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), concat);
m.add_return({t}); m.add_return({t});
auto out_shape = m.get_output_shapes().back(); auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
...@@ -374,14 +380,15 @@ TEST_CASE(concat_transpose3) ...@@ -374,14 +380,15 @@ TEST_CASE(concat_transpose3)
TEST_CASE(concat_transpose4) TEST_CASE(concat_transpose4)
{ {
migraphx::module m; migraphx::module m;
auto sx = migraphx::shape{migraphx::shape::float_type, {1, 1, 12, 64}}; auto sx = migraphx::shape{migraphx::shape::float_type, {1, 1, 12, 64}};
auto sy = migraphx::shape{migraphx::shape::float_type, {1, 12, 1, 64}}; auto sy = migraphx::shape{migraphx::shape::float_type, {1, 12, 1, 64}};
auto x = m.add_parameter("x", sx); auto x = m.add_parameter("x", sx);
auto y = m.add_parameter("y", sy); auto y = m.add_parameter("y", sy);
auto xt = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), x); auto xt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto yt = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), y); auto yt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), y);
auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), xt, yt); auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), xt, yt);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat); auto t =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), concat);
m.add_return({t}); m.add_return({t});
migraphx::module m1 = m; migraphx::module m1 = m;
...@@ -438,7 +445,7 @@ TEST_CASE(multibroadcast_simplify) ...@@ -438,7 +445,7 @@ TEST_CASE(multibroadcast_simplify)
std::vector<size_t> s_lens{1, 2, 3, 4}; std::vector<size_t> s_lens{1, 2, 3, 4};
auto s = migraphx::shape{migraphx::shape::float_type, s_lens}; auto s = migraphx::shape{migraphx::shape::float_type, s_lens};
auto x = m.add_parameter("x", s); auto x = m.add_parameter("x", s);
auto y = m.add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", s_lens}}), x); auto y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s_lens}}), x);
m.add_instruction(migraphx::make_op("mul"), y, y); m.add_instruction(migraphx::make_op("mul"), y, y);
auto n = std::distance(m.begin(), m.end()); auto n = std::distance(m.begin(), m.end());
run_pass(m); run_pass(m);
...@@ -547,8 +554,8 @@ TEST_CASE(optimize_resize) ...@@ -547,8 +554,8 @@ TEST_CASE(optimize_resize)
std::vector<int64_t> dims = {1, 1, 2, 1, 2, 1}; std::vector<int64_t> dims = {1, 1, 2, 1, 2, 1};
auto rspx = m.add_instruction(migraphx::make_op("reshape", {{"dims", dims}}), inx); auto rspx = m.add_instruction(migraphx::make_op("reshape", {{"dims", dims}}), inx);
std::vector<int64_t> mb_dims = {1, 2, 2, 2, 2, 3}; std::vector<int64_t> mb_dims = {1, 2, 2, 2, 2, 3};
auto mbx = m.add_instruction( auto mbx =
migraphx::make_op("multibroadcast", {{"output_lens", mb_dims}}), rspx); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_dims}}), rspx);
auto std_mb = m.add_instruction(migraphx::make_op("contiguous"), mbx); auto std_mb = m.add_instruction(migraphx::make_op("contiguous"), mbx);
std::vector<int64_t> orig_dims = {1, 2, 4, 6}; std::vector<int64_t> orig_dims = {1, 2, 4, 6};
auto rmb = m.add_instruction(migraphx::make_op("reshape", {{"dims", orig_dims}}), std_mb); auto rmb = m.add_instruction(migraphx::make_op("reshape", {{"dims", orig_dims}}), std_mb);
...@@ -707,20 +714,21 @@ TEST_CASE(optimize_where_true) ...@@ -707,20 +714,21 @@ TEST_CASE(optimize_where_true)
return m; return m;
}; };
auto create_opt_module = [&](std::string name) { auto return_xy = [&](bool cond) {
migraphx::module m; migraphx::module m;
auto in = m.add_parameter(std::move(name), s); auto x = m.add_parameter("X", s);
m.add_return({in}); auto y = m.add_parameter("Y", s);
cond ? m.add_return({x}) : m.add_return({y});
return m; return m;
}; };
auto m = create_where_module(true); auto m = create_where_module(true);
run_pass(m); run_pass(m);
EXPECT(m == create_opt_module("X")); EXPECT(m == return_xy(true));
auto m1 = create_where_module(false); auto m1 = create_where_module(false);
run_pass(m1); run_pass(m1);
EXPECT(m1 == create_opt_module("Y")); EXPECT(m1 == return_xy(false));
} }
TEST_CASE(where_different_cond_values) TEST_CASE(where_different_cond_values)
...@@ -847,10 +855,10 @@ TEST_CASE(reshape_cont) ...@@ -847,10 +855,10 @@ TEST_CASE(reshape_cont)
migraphx::shape sx{migraphx::shape::float_type, {1, 4, 1}}; migraphx::shape sx{migraphx::shape::float_type, {1, 4, 1}};
migraphx::shape sy{migraphx::shape::float_type, {2, 2, 2, 6}}; migraphx::shape sy{migraphx::shape::float_type, {2, 2, 2, 6}};
auto inx = m.add_parameter("x", sx); auto inx = m.add_parameter("x", sx);
auto iny = m.add_parameter("y", sy); auto iny = m.add_parameter("y", sy);
auto mb_inx = m.add_instruction( auto mb_inx =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 4, 6}}}), inx); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 6}}}), inx);
auto std_inx = m.add_instruction(migraphx::make_op("contiguous"), mb_inx); auto std_inx = m.add_instruction(migraphx::make_op("contiguous"), mb_inx);
auto rsp = auto rsp =
m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), std_inx); m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), std_inx);
...@@ -868,10 +876,10 @@ TEST_CASE(reshape_cont) ...@@ -868,10 +876,10 @@ TEST_CASE(reshape_cont)
migraphx::shape sx{migraphx::shape::float_type, {1, 4, 1}}; migraphx::shape sx{migraphx::shape::float_type, {1, 4, 1}};
migraphx::shape sy{migraphx::shape::float_type, {2, 2, 2, 6}}; migraphx::shape sy{migraphx::shape::float_type, {2, 2, 2, 6}};
auto inx = m.add_parameter("x", sx); auto inx = m.add_parameter("x", sx);
auto iny = m.add_parameter("y", sy); auto iny = m.add_parameter("y", sy);
auto mb_inx = m.add_instruction( auto mb_inx =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 4, 6}}}), inx); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 6}}}), inx);
auto rsp_iny = m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4, 6}}}), iny); auto rsp_iny = m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4, 6}}}), iny);
auto sum = m.add_instruction(migraphx::make_op("add"), mb_inx, rsp_iny); auto sum = m.add_instruction(migraphx::make_op("add"), mb_inx, rsp_iny);
auto r = m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), sum); auto r = m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), sum);
...@@ -890,15 +898,16 @@ TEST_CASE(reshape_input_non_std) ...@@ -890,15 +898,16 @@ TEST_CASE(reshape_input_non_std)
migraphx::shape sx{migraphx::shape::float_type, {1, 4, 1}}; migraphx::shape sx{migraphx::shape::float_type, {1, 4, 1}};
migraphx::shape sy{migraphx::shape::float_type, {2, 6, 2, 2}}; migraphx::shape sy{migraphx::shape::float_type, {2, 6, 2, 2}};
auto inx = m.add_parameter("x", sx); auto inx = m.add_parameter("x", sx);
auto iny = m.add_parameter("y", sy); auto iny = m.add_parameter("y", sy);
auto mb_inx = m.add_instruction( auto mb_inx =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 4, 6}}}), inx); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 6}}}), inx);
auto std_inx = m.add_instruction(migraphx::make_op("contiguous"), mb_inx); auto std_inx = m.add_instruction(migraphx::make_op("contiguous"), mb_inx);
auto rsp = auto rsp =
m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), std_inx); m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), std_inx);
auto ty = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), iny); auto ty =
auto r = m.add_instruction(migraphx::make_op("add"), rsp, ty); m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), iny);
auto r = m.add_instruction(migraphx::make_op("add"), rsp, ty);
m.add_return({r}); m.add_return({r});
return m; return m;
...@@ -917,10 +926,10 @@ TEST_CASE(reshape_cont_nonpw) ...@@ -917,10 +926,10 @@ TEST_CASE(reshape_cont_nonpw)
migraphx::shape sx{migraphx::shape::float_type, {1, 4, 1}}; migraphx::shape sx{migraphx::shape::float_type, {1, 4, 1}};
migraphx::shape sy{migraphx::shape::float_type, {2, 2, 2, 6}}; migraphx::shape sy{migraphx::shape::float_type, {2, 2, 2, 6}};
auto inx = m.add_parameter("x", sx); auto inx = m.add_parameter("x", sx);
auto iny = m.add_parameter("y", sy); auto iny = m.add_parameter("y", sy);
auto mb_inx = m.add_instruction( auto mb_inx =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 4, 6}}}), inx); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 6}}}), inx);
auto std_inx = m.add_instruction(migraphx::make_op("contiguous"), mb_inx); auto std_inx = m.add_instruction(migraphx::make_op("contiguous"), mb_inx);
auto rsp = auto rsp =
m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), std_inx); m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), std_inx);
...@@ -936,4 +945,177 @@ TEST_CASE(reshape_cont_nonpw) ...@@ -936,4 +945,177 @@ TEST_CASE(reshape_cont_nonpw)
EXPECT(m1 == create_module()); EXPECT(m1 == create_module());
} }
TEST_CASE(transpose_contiguous_reshape_unary)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto reshape_ins1 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto reshape_ins2 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins2);
m1.add_instruction(pass_op{}, relu);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto reshape_ins1 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto relu = m2.add_instruction(migraphx::make_op("relu"), transpose_ins);
auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), relu);
auto reshape_ins2 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins);
m2.add_instruction(pass_op{}, reshape_ins2);
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_contiguous_squeeze_unary)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 1, 5}});
auto transpose_ins =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto sq_ins = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), cont_ins);
auto rsqrt = m1.add_instruction(migraphx::make_op("rsqrt"), sq_ins);
m1.add_instruction(pass_op{}, rsqrt);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 1, 5}});
auto transpose_ins =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto rsqrt = m2.add_instruction(migraphx::make_op("rsqrt"), transpose_ins);
auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), rsqrt);
auto sq_ins = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), cont_ins);
m2.add_instruction(pass_op{}, sq_ins);
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_contiguous_unsqueeze_unary)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto transpose_ins =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto unsq_ins =
m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins);
auto round = m1.add_instruction(migraphx::make_op("round"), unsq_ins);
m1.add_instruction(pass_op{}, round);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto transpose_ins =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto round = m2.add_instruction(migraphx::make_op("round"), transpose_ins);
auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), round);
auto unsq_ins =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins);
m2.add_instruction(pass_op{}, unsq_ins);
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_contiguous_reshape_binary_packed)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 128, 28, 28}});
auto w1 = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto conv1 = m1.add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
x,
w1); // (2, 256, 28, 28)
auto w2 = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {512, 256, 1, 1}}));
auto conv2 = m1.add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
conv1,
w2); // (2, 512, 14, 14)
auto conv2_rsp1 = m1.add_instruction(
migraphx::make_op("reshape", {{"dims", {2, 2, 2, 128, 14, 14}}}), conv2);
auto conv2_trans = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), conv2_rsp1);
auto conv2_cont = m1.add_instruction(migraphx::make_op("contiguous"), conv2_trans);
auto conv2_rsp2 = m1.add_instruction(
migraphx::make_op("reshape", {{"dims", {2, 128, 28, 28}}}), conv2_cont);
auto add_ins = m1.add_instruction(migraphx::make_op("add"), conv2_rsp2, x);
m1.add_instruction(pass_op{}, add_ins);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 128, 28, 28}});
auto w1 = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto conv1 = m2.add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
x,
w1); // (2, 256, 28, 28)
auto w2 = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {512, 256, 1, 1}}));
auto conv2 = m2.add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
conv1,
w2); // (2, 512, 14, 14)
auto conv2_rsp = m2.add_instruction(
migraphx::make_op("reshape", {{"dims", {2, 2, 2, 128, 14, 14}}}), conv2);
auto conv2_trans = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), conv2_rsp);
auto x_rsp =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 128, 14, 2, 14, 2}}}), x);
auto add_ins = m2.add_instruction(migraphx::make_op("add"), conv2_trans, x_rsp);
auto add_rsp =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 128, 28, 28}}}), add_ins);
m2.add_instruction(pass_op{}, add_rsp);
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_contiguous_reshape_binary_broadcast)
{
migraphx::module m1;
{
migraphx::shape sx{migraphx::shape::float_type, {4}};
migraphx::shape sy{migraphx::shape::float_type, {2, 6, 2, 2}};
auto x = m1.add_parameter("x", sx);
auto y = m1.add_parameter("y", sy);
auto x_brcst = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 4, 6}}}), x);
auto y_trans =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), y);
auto y_cont = m1.add_instruction(migraphx::make_op("contiguous"), y_trans);
auto y_rsp =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4, 6}}}), y_cont);
auto r = m1.add_instruction(migraphx::make_op("add"), y_rsp, x_brcst);
m1.add_return({r});
}
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/stringutils.hpp>
#include <test.hpp>
TEST_CASE(interpolate_string_simple1)
{
std::string input = "Hello ${w}!";
auto s = migraphx::interpolate_string(input, {{"w", "world"}});
EXPECT(s == "Hello world!");
}
TEST_CASE(interpolate_string_simple2)
{
std::string input = "${hello}";
auto s = migraphx::interpolate_string(input, {{"hello", "bye"}});
EXPECT(s == "bye");
}
TEST_CASE(interpolate_string_unbalanced)
{
std::string input = "${hello";
EXPECT(test::throws([&] { migraphx::interpolate_string(input, {{"hello", "bye"}}); }));
}
TEST_CASE(interpolate_string_extra_space)
{
std::string input = "${ hello }";
auto s = migraphx::interpolate_string(input, {{"hello", "bye"}});
EXPECT(s == "bye");
}
TEST_CASE(interpolate_string_multiple)
{
std::string input = "${h} ${w}!";
auto s = migraphx::interpolate_string(input, {{"w", "world"}, {"h", "Hello"}});
EXPECT(s == "Hello world!");
}
TEST_CASE(interpolate_string_next)
{
std::string input = "${hh}${ww}!";
auto s = migraphx::interpolate_string(input, {{"ww", "world"}, {"hh", "Hello"}});
EXPECT(s == "Helloworld!");
}
TEST_CASE(interpolate_string_dollar_sign)
{
std::string input = "$hello";
auto s = migraphx::interpolate_string(input, {{"hello", "bye"}});
EXPECT(s == "$hello");
}
TEST_CASE(interpolate_string_missing)
{
std::string input = "${hello}";
EXPECT(test::throws([&] { migraphx::interpolate_string(input, {{"h", "bye"}}); }));
}
TEST_CASE(interpolate_string_custom1)
{
std::string input = "****{{a}}****";
auto s = migraphx::interpolate_string(input, {{"a", "b"}}, "{{", "}}");
EXPECT(s == "****b****");
}
TEST_CASE(interpolate_string_custom2)
{
std::string input = "****{{{a}}}****";
auto s = migraphx::interpolate_string(input, {{"a", "b"}}, "{{{", "}}}");
EXPECT(s == "****b****");
}
TEST_CASE(interpolate_string_custom3)
{
std::string input = "****{{{{a}}}}****";
auto s = migraphx::interpolate_string(input, {{"a", "b"}}, "{{{{", "}}}}");
EXPECT(s == "****b****");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -12,6 +12,11 @@ TEST_CASE(make_target) ...@@ -12,6 +12,11 @@ TEST_CASE(make_target)
} }
} }
TEST_CASE(make_invalid_target)
{
EXPECT(test::throws([&] { migraphx::make_target("mi100"); }));
}
TEST_CASE(targets) TEST_CASE(targets)
{ {
auto ts = migraphx::get_targets(); auto ts = migraphx::get_targets();
......
...@@ -87,7 +87,7 @@ TEST_CASE(add_bcast_test) ...@@ -87,7 +87,7 @@ TEST_CASE(add_bcast_test)
auto l0 = mm->add_parameter("0", s0); auto l0 = mm->add_parameter("0", s0);
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}});
auto l2 = auto l2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", s0.lens()}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s0.lens()}}), l1);
mm->add_instruction(migraphx::make_op("add"), l0, l2); mm->add_instruction(migraphx::make_op("add"), l0, l2);
auto prog = optimize_tf("add_bcast_test.pb", false); auto prog = optimize_tf("add_bcast_test.pb", false);
...@@ -151,9 +151,9 @@ TEST_CASE(batchmatmul_test) ...@@ -151,9 +151,9 @@ TEST_CASE(batchmatmul_test)
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}});
auto trans_l0 = auto trans_l0 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l0); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l0);
auto trans_l1 = auto trans_l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1); mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1);
auto prog = optimize_tf("batchmatmul_test.pb", false); auto prog = optimize_tf("batchmatmul_test.pb", false);
...@@ -220,7 +220,7 @@ TEST_CASE(biasadd_test) ...@@ -220,7 +220,7 @@ TEST_CASE(biasadd_test)
auto l0 = mm->add_parameter("0", s0); auto l0 = mm->add_parameter("0", s0);
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
auto l2 = mm->add_instruction( auto l2 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l0->get_shape().lens()}}), l1); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l0->get_shape().lens()}}), l1);
mm->add_instruction(migraphx::make_op("add"), l0, l2); mm->add_instruction(migraphx::make_op("add"), l0, l2);
auto prog = optimize_tf("biasadd_test.pb", true); auto prog = optimize_tf("biasadd_test.pb", true);
...@@ -238,7 +238,7 @@ TEST_CASE(biasadd_scalar_test) ...@@ -238,7 +238,7 @@ TEST_CASE(biasadd_scalar_test)
auto l1 = mm->add_literal( auto l1 = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1.0}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1.0}});
auto l2 = mm->add_instruction( auto l2 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l0->get_shape().lens()}}), l1); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l0->get_shape().lens()}}), l1);
mm->add_instruction(migraphx::make_op("add"), l0, l2); mm->add_instruction(migraphx::make_op("add"), l0, l2);
auto prog = optimize_tf("biasadd_scalar_test.pb", true); auto prog = optimize_tf("biasadd_scalar_test.pb", true);
...@@ -308,7 +308,8 @@ migraphx::program create_conv() ...@@ -308,7 +308,8 @@ migraphx::program create_conv()
op.padding = {1, 1, 1, 1}; op.padding = {1, 1, 1, 1};
op.stride = {1, 1}; op.stride = {1, 1};
op.dilation = {1, 1}; op.dilation = {1, 1};
auto l2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 2, 0, 1}}}), l1); auto l2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), l1);
mm->add_instruction(op, l0, l2); mm->add_instruction(op, l0, l2);
return p; return p;
} }
...@@ -359,10 +360,10 @@ TEST_CASE(conv_relu6_test) ...@@ -359,10 +360,10 @@ TEST_CASE(conv_relu6_test)
auto l0 = std::prev(mm->end()); auto l0 = std::prev(mm->end());
auto min_val = mm->add_literal(0.0f); auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f); auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction( min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val); min_val);
max_val = mm->add_instruction( max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val); max_val);
mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
auto prog = optimize_tf("conv_relu6_test.pb", true); auto prog = optimize_tf("conv_relu6_test.pb", true);
...@@ -387,7 +388,8 @@ TEST_CASE(depthwiseconv_test) ...@@ -387,7 +388,8 @@ TEST_CASE(depthwiseconv_test)
op.stride = {1, 1}; op.stride = {1, 1};
op.dilation = {1, 1}; op.dilation = {1, 1};
op.group = 3; op.group = 3;
auto l3 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 2, 0, 1}}}), l1); auto l3 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), l1);
auto l4 = mm->add_instruction(migraphx::make_op("contiguous"), l3); auto l4 = mm->add_instruction(migraphx::make_op("contiguous"), l3);
auto l5 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1, 3, 3}}}), l4); auto l5 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1, 3, 3}}}), l4);
mm->add_instruction(op, l0, l5); mm->add_instruction(op, l0, l5);
...@@ -463,8 +465,10 @@ TEST_CASE(matmul_test) ...@@ -463,8 +465,10 @@ TEST_CASE(matmul_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}});
auto trans_l0 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l0); auto trans_l0 =
auto trans_l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l0);
auto trans_l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1); mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1);
auto prog = optimize_tf("matmul_test.pb", false); auto prog = optimize_tf("matmul_test.pb", false);
...@@ -497,7 +501,8 @@ TEST_CASE(mean_test_nhwc) ...@@ -497,7 +501,8 @@ TEST_CASE(mean_test_nhwc)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}}; migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0); auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
migraphx::op::reduce_mean op{{1, 2}}; migraphx::op::reduce_mean op{{1, 2}};
auto l2 = mm->add_instruction(op, l1); auto l2 = mm->add_instruction(op, l1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2}}}), l2); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2}}}), l2);
...@@ -595,11 +600,14 @@ TEST_CASE(pack_test_nhwc) ...@@ -595,11 +600,14 @@ TEST_CASE(pack_test_nhwc)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt0 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0); auto lt0 =
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto lt1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l1); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); auto lt1 =
auto lt2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l1);
auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l2);
std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2}; std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
std::vector<migraphx::instruction_ref> unsqueezed_args; std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t nchw_axis = 3; int64_t nchw_axis = 3;
...@@ -641,8 +649,8 @@ TEST_CASE(pooling_test) ...@@ -641,8 +649,8 @@ TEST_CASE(pooling_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
migraphx::op::pooling avg_pool_op{"average"}; migraphx::op::pooling avg_pool_op{migraphx::op::pooling_mode::average};
migraphx::op::pooling max_pool_op{"max"}; migraphx::op::pooling max_pool_op{migraphx::op::pooling_mode::max};
avg_pool_op.stride = {2, 2}; avg_pool_op.stride = {2, 2};
max_pool_op.stride = {2, 2}; max_pool_op.stride = {2, 2};
avg_pool_op.lengths = {2, 2}; avg_pool_op.lengths = {2, 2};
...@@ -688,10 +696,10 @@ TEST_CASE(relu6_test) ...@@ -688,10 +696,10 @@ TEST_CASE(relu6_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_lens}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_lens});
auto min_val = mm->add_literal(0.0f); auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f); auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction( min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val); min_val);
max_val = mm->add_instruction( max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val); max_val);
mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
auto prog = optimize_tf("relu6_test.pb", false); auto prog = optimize_tf("relu6_test.pb", false);
...@@ -882,7 +890,8 @@ TEST_CASE(stridedslice_test) ...@@ -882,7 +890,8 @@ TEST_CASE(stridedslice_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0); auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
std::size_t num_axes = 4; std::size_t num_axes = 4;
migraphx::op::slice op; migraphx::op::slice op;
op.starts = {0, 0, 0, 0}; op.starts = {0, 0, 0, 0};
...@@ -917,9 +926,11 @@ TEST_CASE(stridedslice_masks_test) ...@@ -917,9 +926,11 @@ TEST_CASE(stridedslice_masks_test)
mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
std::vector<int>{1, 1, 1, 1}); std::vector<int>{1, 1, 1, 1});
auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0); auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto l2 = mm->add_instruction(op, l1); auto l2 = mm->add_instruction(op, l1);
auto l3 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 3, 1, 2}}}), l2); auto l3 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), l2);
mm->add_return({l3}); mm->add_return({l3});
auto prog = parse_tf("stridedslice_masks_test.pb", true); auto prog = parse_tf("stridedslice_masks_test.pb", true);
...@@ -961,7 +972,7 @@ TEST_CASE(transpose_test) ...@@ -961,7 +972,7 @@ TEST_CASE(transpose_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
migraphx::shape s0{migraphx::shape::int32_type, {4}}; migraphx::shape s0{migraphx::shape::int32_type, {4}};
mm->add_literal(migraphx::literal{s0, {0, 2, 3, 1}}); mm->add_literal(migraphx::literal{s0, {0, 2, 3, 1}});
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto prog = optimize_tf("transpose_test.pb", false); auto prog = optimize_tf("transpose_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
......
...@@ -57,6 +57,15 @@ TEST_CASE(value_construct_string) ...@@ -57,6 +57,15 @@ TEST_CASE(value_construct_string)
EXPECT(v.get_key().empty()); EXPECT(v.get_key().empty());
} }
TEST_CASE(value_construct_key_string_literal_pair)
{
// Use parens instead {} to construct to test the key-pair constructor
migraphx::value v("key", "one");
EXPECT(v.is_string());
EXPECT(v.get_string() == "one");
EXPECT(v.get_key() == "key");
}
TEST_CASE(value_construct_float) TEST_CASE(value_construct_float)
{ {
migraphx::value v = 1.0; migraphx::value v = 1.0;
...@@ -167,6 +176,15 @@ TEST_CASE(value_copy_assign_keyless) ...@@ -167,6 +176,15 @@ TEST_CASE(value_copy_assign_keyless)
EXPECT(v1.without_key() == v2.without_key()); EXPECT(v1.without_key() == v2.without_key());
} }
TEST_CASE(value_assign_key_string_literal_pair)
{
migraphx::value v = migraphx::value::object{};
v["key"] = "one";
EXPECT(v["key"].is_string());
EXPECT(v["key"].get_string() == "one");
EXPECT(v["key"].get_key() == "key");
}
TEST_CASE(value_construct_array) TEST_CASE(value_construct_array)
{ {
migraphx::value v = {1, 2, 3}; migraphx::value v = {1, 2, 3};
...@@ -522,6 +540,14 @@ TEST_CASE(value_construct_object_string_mixed_value) ...@@ -522,6 +540,14 @@ TEST_CASE(value_construct_object_string_mixed_value)
EXPECT(v.at("two").get_int64() == 2); EXPECT(v.at("two").get_int64() == 2);
} }
template <class Expression>
auto compare_predicate(const Expression& e)
{
bool result = e.value();
return test::make_predicate(test::as_string(e) + " => " + test::as_string(result),
[=] { return result; });
}
TEST_CASE(value_compare) TEST_CASE(value_compare)
{ {
EXPECT(migraphx::value(1) == migraphx::value(1)); EXPECT(migraphx::value(1) == migraphx::value(1));
...@@ -535,6 +561,46 @@ TEST_CASE(value_compare) ...@@ -535,6 +561,46 @@ TEST_CASE(value_compare)
EXPECT(migraphx::value(2) > migraphx::value(1)); EXPECT(migraphx::value(2) > migraphx::value(1));
EXPECT(migraphx::value(2) >= migraphx::value(1)); EXPECT(migraphx::value(2) >= migraphx::value(1));
EXPECT(migraphx::value(1) >= migraphx::value(1)); EXPECT(migraphx::value(1) >= migraphx::value(1));
EXPECT(migraphx::value(1) != migraphx::value("1"));
EXPECT(migraphx::value(1) != migraphx::value());
}
// NOLINTNEXTLINE
#define MIGRAPHX_VALUE_TEST_COMPARE(...) compare_predicate(TEST_CAPTURE(__VA_ARGS__))
// NOLINTNEXTLINE
#define EXPECT_TOTALLY_ORDERED_IMPL(_, x, y) \
EXPECT(_(x <= y) or _(x >= y)); \
EXPECT(_(x < y) or _(x > y) or _(x == y)); \
EXPECT((_(x < y) or _(x > y)) == _(x != y)); \
EXPECT(_(x < y) == _(y > x)); \
EXPECT(_(x <= y) == _(y >= x)); \
EXPECT(_(x < y) != _(x >= y)); \
EXPECT(_(x > y) != _(x <= y)); \
EXPECT(_(x == y) != _(x != y))
// NOLINTNEXTLINE
#define EXPECT_TOTALLY_ORDERED(x, y) \
EXPECT_TOTALLY_ORDERED_IMPL(MIGRAPHX_VALUE_TEST_COMPARE, x, y); \
EXPECT_TOTALLY_ORDERED_IMPL(MIGRAPHX_VALUE_TEST_COMPARE, y, x)
// NOLINTNEXTLINE(readability-function-size)
TEST_CASE(value_compare_ordered)
{
EXPECT_TOTALLY_ORDERED(migraphx::value(), migraphx::value());
EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value(1));
EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value(2));
EXPECT_TOTALLY_ORDERED(migraphx::value("key", 1), migraphx::value("key", 1));
EXPECT_TOTALLY_ORDERED(migraphx::value("key1", 1), migraphx::value("key2", 2));
EXPECT_TOTALLY_ORDERED(migraphx::value("key", 1), migraphx::value("key", 2));
EXPECT_TOTALLY_ORDERED(migraphx::value("key1", 1), migraphx::value("key2", 2));
EXPECT_TOTALLY_ORDERED(migraphx::value("key", 1), migraphx::value("key", "2"));
EXPECT_TOTALLY_ORDERED(migraphx::value("key1", 1), migraphx::value("key2", "2"));
EXPECT_TOTALLY_ORDERED(migraphx::value(std::int64_t{1}), migraphx::value(std::uint64_t{1}));
EXPECT_TOTALLY_ORDERED(migraphx::value(std::int64_t{1}), migraphx::value(std::uint64_t{2}));
EXPECT_TOTALLY_ORDERED(migraphx::value(std::int64_t{2}), migraphx::value(std::uint64_t{1}));
EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value("1"));
EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value());
} }
TEST_CASE(value_to_from_string) TEST_CASE(value_to_from_string)
...@@ -835,4 +901,38 @@ TEST_CASE(value_or_null) ...@@ -835,4 +901,38 @@ TEST_CASE(value_or_null)
EXPECT(v.value_or(3) == 3); EXPECT(v.value_or(3) == 3);
} }
TEST_CASE(value_get_default)
{
migraphx::value v = {{"key", 1}};
EXPECT(v.get("key", 3) == 1);
EXPECT(v.get("missing", 3) == 3);
}
TEST_CASE(value_get_default_vector)
{
std::vector<int> ints = {1, 2, 3};
std::vector<int> fallback = {-1};
migraphx::value v = {{"key", ints}};
EXPECT(v.get("key", fallback) == ints);
EXPECT(v.get("missing", fallback) == fallback);
EXPECT(v.get("missing", {-1}) == fallback);
}
TEST_CASE(value_get_default_string_literal)
{
migraphx::value v = {{"key", "hello"}};
EXPECT(v.get("key", "none") == "hello");
EXPECT(v.get("missing", "none") == "none");
}
TEST_CASE(value_get_default_string_literal_vector)
{
std::vector<std::string> strings = {"1", "2", "3"};
std::vector<std::string> fallback = {"none"};
migraphx::value v = {{"key", strings}};
EXPECT(v.get("key", fallback) == strings);
EXPECT(v.get("missing", fallback) == fallback);
EXPECT(v.get("missing", {"none"}) == fallback);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
file(GLOB VERIFY_TESTS *.cpp) file(GLOB VERIFY_TESTS ${CONFIGURE_DEPENDS} *.cpp)
add_executable(test_verify ${VERIFY_TESTS}) add_executable(test_verify ${VERIFY_TESTS})
add_dependencies(tests test_verify) add_dependencies(tests test_verify)
......
#include "auto_print.hpp" #include "auto_print.hpp"
#include <map> #include <map>
#include <exception> #include <exception>
#include <iostream>
#ifdef __clang__ #ifdef __clang__
#pragma clang diagnostic push #pragma clang diagnostic push
......
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -14,15 +15,14 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1> ...@@ -14,15 +15,14 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = auto tl1 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = auto tl2 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2);
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3);
return p; return p;
} }
}; };
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -17,8 +18,7 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2> ...@@ -17,8 +18,7 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3);
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), l1, l2, l3);
return p; return p;
} }
}; };
...@@ -15,7 +15,7 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3> ...@@ -15,7 +15,7 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), l1, l2); mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2);
return p; return p;
} }
}; };
...@@ -13,13 +13,13 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4> ...@@ -13,13 +13,13 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 4, 6, 3}}; migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 4, 6, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 2, 6, 3}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 2, 6, 3}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl1 = auto tl1 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 0, 1, 2}}}), l1); migraphx::make_op("transpose", {{"permutation", {3, 0, 1, 2}}}), l1);
auto tl2 = auto tl2 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 1, 2, 0}}}), l2); migraphx::make_op("transpose", {{"permutation", {3, 1, 2, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, tl2); mm->add_instruction(migraphx::make_op("quant_dot"), tl1, tl2);
return p; return p;
} }
}; };
...@@ -13,15 +13,15 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5> ...@@ -13,15 +13,15 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 7, 2}}; migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 7, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 5, 7}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 5, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl1 = auto tl1 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
auto sl1 = mm->add_instruction(migraphx::make_op("add"), tl1, tl1); auto sl1 = mm->add_instruction(migraphx::make_op("add"), tl1, tl1);
auto tl2 = auto tl2 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto sl2 = mm->add_instruction(migraphx::make_op("add"), tl2, tl2); auto sl2 = mm->add_instruction(migraphx::make_op("add"), tl2, tl2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}}), sl1, sl2); mm->add_instruction(migraphx::make_op("quant_dot"), sl1, sl2);
return p; return p;
} }
}; };
...@@ -16,7 +16,7 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv> ...@@ -16,7 +16,7 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2); auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2);
auto bul2 = mm->add_instruction( auto bul2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 5, 1}}}), ul2); migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 5, 1}}}), ul2);
mm->add_instruction(migraphx::make_op("dot"), l1, bul2); mm->add_instruction(migraphx::make_op("dot"), l1, bul2);
......
...@@ -12,10 +12,10 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1> ...@@ -12,10 +12,10 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto bl2 = mm->add_instruction( auto bl2 =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4}}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4}}}), l2);
mm->add_instruction(migraphx::make_op("dot"), l1, bl2); mm->add_instruction(migraphx::make_op("dot"), l1, bl2);
......
...@@ -12,10 +12,10 @@ struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2> ...@@ -12,10 +12,10 @@ struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto bl2 = mm->add_instruction( auto bl2 =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4}}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4}}}), l2);
mm->add_instruction(migraphx::make_op("dot"), l1, bl2); mm->add_instruction(migraphx::make_op("dot"), l1, bl2);
......
...@@ -12,9 +12,9 @@ struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3> ...@@ -12,9 +12,9 @@ struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction( auto bl1 =
migraphx::make_op("multibroadcast", {{"output_lens", {3, 2, 3}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 3}}}), l1);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
mm->add_instruction(migraphx::make_op("dot"), bl1, l2); mm->add_instruction(migraphx::make_op("dot"), bl1, l2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment