Commit c79661d6 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix review comments

parent e35ad617
...@@ -20,7 +20,7 @@ void auto_contiguous::apply(module& p) const ...@@ -20,7 +20,7 @@ void auto_contiguous::apply(module& p) const
auto args = ins->inputs(); auto args = ins->inputs();
auto new_args = args; auto new_args = args;
std::transform(args.begin(), args.end(), new_args.begin(), [&](auto in) { std::transform(args.begin(), args.end(), new_args.begin(), [&](auto in) {
return p.replace_instruction(ins, make_op("contiguous"), in); return p.insert_instruction(ins, make_op("contiguous"), in);
}); });
if(new_args != args) if(new_args != args)
...@@ -29,6 +29,18 @@ void auto_contiguous::apply(module& p) const ...@@ -29,6 +29,18 @@ void auto_contiguous::apply(module& p) const
} }
} }
} }
auto last = std::prev(p.end());
for(auto ins : iterator_for(p))
{
if (ins->outputs().empty() and ins != last) continue;
shape s = ins->get_shape();
if(not s.standard() and s.elements() != 0)
{
auto c = p.insert_instruction(std::next(ins), make_op("contiguous"), ins);
p.replace_instruction(ins, c);
}
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -66,7 +66,7 @@ struct reduce_op : op_name<Derived> ...@@ -66,7 +66,7 @@ struct reduce_op : op_name<Derived>
{ {
value normalize; value normalize;
normalize["axes"] = value::array{normalize_attribute::include_min}; normalize["axes"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}}; return {{"normalize_axes", normalize}, {"standard_input_shape", true}};
} }
std::vector<int64_t> tune_axes(std::size_t n_dim) const std::vector<int64_t> tune_axes(std::size_t n_dim) const
......
...@@ -26,6 +26,7 @@ const auto& reshaper_names() ...@@ -26,6 +26,7 @@ const auto& reshaper_names()
static const std::unordered_set<std::string> names = { static const std::unordered_set<std::string> names = {
"flatten", "flatten",
"reshape", "reshape",
"contiguous",
"squeeze", "squeeze",
"unsqueeze" "unsqueeze"
}; };
......
...@@ -1264,10 +1264,8 @@ TEST_CASE(flatten_test) ...@@ -1264,10 +1264,8 @@ TEST_CASE(flatten_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0); mm->add_instruction(migraphx::make_op("flatten", {{"axis", 2}}), l0);
mm->add_instruction(migraphx::make_op("flatten", {{"axis", 2}}), cl0); mm->add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), l0);
auto cl1 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), cl1);
auto prog = optimize_onnx("flatten_test.onnx"); auto prog = optimize_onnx("flatten_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -1308,9 +1306,7 @@ TEST_CASE(gather_test) ...@@ -1308,9 +1306,7 @@ TEST_CASE(gather_test)
auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}}); auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}});
int axis = 1; int axis = 1;
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0); mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l0, l1);
auto cl1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), cl0, cl1);
auto prog = optimize_onnx("gather_test.onnx"); auto prog = optimize_onnx("gather_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -1330,13 +1326,11 @@ TEST_CASE(gather_elements_axis0_test) ...@@ -1330,13 +1326,11 @@ TEST_CASE(gather_elements_axis0_test)
auto l_ind_axis_indices = auto l_ind_axis_indices =
mm->add_literal(migraphx::literal{ind_s, ind_axis_indices.begin(), ind_axis_indices.end()}); mm->add_literal(migraphx::literal{ind_s, ind_axis_indices.begin(), ind_axis_indices.end()});
auto l_stride = mm->add_literal(migraphx::literal{{migraphx::shape::int32_type, {1}}, {4}}); auto l_stride = mm->add_literal(migraphx::literal{{migraphx::shape::int32_type, {1}}, {4}});
auto cdata = mm->add_instruction(migraphx::make_op("contiguous"), data);
auto cindices = mm->add_instruction(migraphx::make_op("contiguous"), indices);
auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), cdata); auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data);
auto lbst_stride = mm->add_instruction( auto lbst_stride = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride); migraphx::make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
auto axis_delta = mm->add_instruction(migraphx::make_op("sub"), cindices, l_ind_axis_indices); auto axis_delta = mm->add_instruction(migraphx::make_op("sub"), indices, l_ind_axis_indices);
auto mul_delta = mm->add_instruction(migraphx::make_op("mul"), axis_delta, lbst_stride); auto mul_delta = mm->add_instruction(migraphx::make_op("mul"), axis_delta, lbst_stride);
auto ind = mm->add_instruction(migraphx::make_op("add"), l_data_indices, mul_delta); auto ind = mm->add_instruction(migraphx::make_op("add"), l_data_indices, mul_delta);
auto ret = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp_data, ind); auto ret = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp_data, ind);
...@@ -1360,13 +1354,12 @@ TEST_CASE(gather_elements_axis1_test) ...@@ -1360,13 +1354,12 @@ TEST_CASE(gather_elements_axis1_test)
mm->add_literal(migraphx::literal{ind_s, ind_indices.begin(), ind_indices.end()}); mm->add_literal(migraphx::literal{ind_s, ind_indices.begin(), ind_indices.end()});
auto l_ind_axis_indices = auto l_ind_axis_indices =
mm->add_literal(migraphx::literal{ind_s, ind_axis_indices.begin(), ind_axis_indices.end()}); mm->add_literal(migraphx::literal{ind_s, ind_axis_indices.begin(), ind_axis_indices.end()});
auto l_stride = mm->add_literal(migraphx::literal{{migraphx::shape::int32_type, {1}}, {1}}); auto l_stride = mm->add_literal(migraphx::literal{{migraphx::shape::int32_type, {1}}, {1}});
auto cdata = mm->add_instruction(migraphx::make_op("contiguous"), data);
auto cindices = mm->add_instruction(migraphx::make_op("contiguous"), indices); auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data);
auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), cdata);
auto lbst_stride = mm->add_instruction( auto lbst_stride = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride); migraphx::make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
auto axis_delta = mm->add_instruction(migraphx::make_op("sub"), cindices, l_ind_axis_indices); auto axis_delta = mm->add_instruction(migraphx::make_op("sub"), indices, l_ind_axis_indices);
auto mul_delta = mm->add_instruction(migraphx::make_op("mul"), axis_delta, lbst_stride); auto mul_delta = mm->add_instruction(migraphx::make_op("mul"), axis_delta, lbst_stride);
auto ind = mm->add_instruction(migraphx::make_op("add"), l_data_indices, mul_delta); auto ind = mm->add_instruction(migraphx::make_op("add"), l_data_indices, mul_delta);
auto ret = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp_data, ind); auto ret = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp_data, ind);
...@@ -2630,14 +2623,8 @@ TEST_CASE(nms_test) ...@@ -2630,14 +2623,8 @@ TEST_CASE(nms_test)
migraphx::shape sst{migraphx::shape::float_type, {1}}; migraphx::shape sst{migraphx::shape::float_type, {1}};
auto st = mm->add_parameter("score_threshold", sst); auto st = mm->add_parameter("score_threshold", sst);
auto cb = mm->add_instruction(migraphx::make_op("contiguous"), b);
auto cs = mm->add_instruction(migraphx::make_op("contiguous"), s);
auto cmo = mm->add_instruction(migraphx::make_op("contiguous"), mo);
auto ciou = mm->add_instruction(migraphx::make_op("contiguous"), iou);
auto cst = mm->add_instruction(migraphx::make_op("contiguous"), st);
auto ret = mm->add_instruction( auto ret = mm->add_instruction(
migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}), cb, cs, cmo, ciou, cst); migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}), b, s, mo, iou, st);
mm->add_return({ret}); mm->add_return({ret});
auto prog = migraphx::parse_onnx("nms_test.onnx"); auto prog = migraphx::parse_onnx("nms_test.onnx");
...@@ -3408,12 +3395,10 @@ TEST_CASE(reshape_test) ...@@ -3408,12 +3395,10 @@ TEST_CASE(reshape_test)
std::vector<int64_t> reshape_dims{3, 8}; std::vector<int64_t> reshape_dims{3, 8};
mm->add_literal( mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims}); migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
op.dims = reshape_dims; op.dims = reshape_dims;
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0); mm->add_instruction(op, l0);
mm->add_instruction(op, cl0); mm->add_instruction(op, l0);
auto cl1 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(op, cl1);
auto prog = optimize_onnx("reshape_test.onnx"); auto prog = optimize_onnx("reshape_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -3453,9 +3438,9 @@ TEST_CASE(resize_downsample_c_test) ...@@ -3453,9 +3438,9 @@ TEST_CASE(resize_downsample_c_test)
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 1, 2}}; migraphx::shape si{migraphx::shape::int32_type, {1, 1, 1, 2}};
std::vector<int> ind = {0, 2}; std::vector<int> ind = {0, 2};
auto li = mm->add_literal(migraphx::literal(si, ind)); auto li = mm->add_literal(migraphx::literal(si, ind));
auto cinx = mm->add_instruction(migraphx::make_op("contiguous"), inx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), cinx); auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), inx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r}); mm->add_return({r});
auto prog = migraphx::parse_onnx("resize_downsample_c_test.onnx"); auto prog = migraphx::parse_onnx("resize_downsample_c_test.onnx");
...@@ -3479,9 +3464,9 @@ TEST_CASE(resize_downsample_f_test) ...@@ -3479,9 +3464,9 @@ TEST_CASE(resize_downsample_f_test)
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 1, 2}}; migraphx::shape si{migraphx::shape::int32_type, {1, 1, 1, 2}};
std::vector<int> ind = {0, 3}; std::vector<int> ind = {0, 3};
auto li = mm->add_literal(migraphx::literal(si, ind)); auto li = mm->add_literal(migraphx::literal(si, ind));
auto cinx = mm->add_instruction(migraphx::make_op("contiguous"), inx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), cinx); auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), inx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r}); mm->add_return({r});
auto prog = migraphx::parse_onnx("resize_downsample_f_test.onnx"); auto prog = migraphx::parse_onnx("resize_downsample_f_test.onnx");
...@@ -3521,8 +3506,7 @@ TEST_CASE(resize_downsample_linear_test) ...@@ -3521,8 +3506,7 @@ TEST_CASE(resize_downsample_linear_test)
auto l1 = mm->add_literal(migraphx::literal(s1, d1)); auto l1 = mm->add_literal(migraphx::literal(s1, d1));
mm->add_instruction(migraphx::make_op("undefined")); mm->add_instruction(migraphx::make_op("undefined"));
auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x); auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), x);
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), cx);
auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind); auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind);
auto slc80 = mm->add_instruction( auto slc80 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data);
...@@ -3575,9 +3559,9 @@ TEST_CASE(resize_outsize_test) ...@@ -3575,9 +3559,9 @@ TEST_CASE(resize_outsize_test)
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}}; migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}};
std::vector<int> ind = {0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3}; std::vector<int> ind = {0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3};
auto li = mm->add_literal(migraphx::literal(si, ind)); auto li = mm->add_literal(migraphx::literal(si, ind));
auto cinx = mm->add_instruction(migraphx::make_op("contiguous"), inx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), cinx); auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), inx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r}); mm->add_return({r});
auto prog = migraphx::parse_onnx("resize_outsize_test.onnx"); auto prog = migraphx::parse_onnx("resize_outsize_test.onnx");
...@@ -3674,8 +3658,7 @@ TEST_CASE(resize_upsample_linear_ac_test) ...@@ -3674,8 +3658,7 @@ TEST_CASE(resize_upsample_linear_ac_test)
auto l1 = mm->add_literal(migraphx::literal(s1, d1)); auto l1 = mm->add_literal(migraphx::literal(s1, d1));
mm->add_instruction(migraphx::make_op("undefined")); mm->add_instruction(migraphx::make_op("undefined"));
auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x); auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), x);
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), cx);
auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind); auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind);
auto slc80 = mm->add_instruction( auto slc80 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data);
...@@ -3770,8 +3753,7 @@ TEST_CASE(resize_upsample_linear_test) ...@@ -3770,8 +3753,7 @@ TEST_CASE(resize_upsample_linear_test)
auto l1 = mm->add_literal(migraphx::literal(s1, d1)); auto l1 = mm->add_literal(migraphx::literal(s1, d1));
mm->add_instruction(migraphx::make_op("undefined")); mm->add_instruction(migraphx::make_op("undefined"));
auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x); auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), x);
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), cx);
auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind); auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind);
auto slc80 = mm->add_instruction( auto slc80 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data);
...@@ -3825,8 +3807,7 @@ TEST_CASE(resize_upsample_pc_test) ...@@ -3825,8 +3807,7 @@ TEST_CASE(resize_upsample_pc_test)
std::vector<int> ind = {0, 1, 1, 2, 3, 3, 0, 1, 1, 2, 3, 3, 4, 5, 5, 6, 7, 7, 4, 5, 5, 6, 7, 7}; std::vector<int> ind = {0, 1, 1, 2, 3, 3, 0, 1, 1, 2, 3, 3, 4, 5, 5, 6, 7, 7, 4, 5, 5, 6, 7, 7};
auto li = mm->add_literal(migraphx::literal(si, ind)); auto li = mm->add_literal(migraphx::literal(si, ind));
auto cinx = mm->add_instruction(migraphx::make_op("contiguous"), inx); auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), inx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), cinx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r}); mm->add_return({r});
...@@ -3853,8 +3834,7 @@ TEST_CASE(resize_upsample_pf_test) ...@@ -3853,8 +3834,7 @@ TEST_CASE(resize_upsample_pf_test)
std::vector<int> ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3}; std::vector<int> ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3};
auto li = mm->add_literal(migraphx::literal(si, ind)); auto li = mm->add_literal(migraphx::literal(si, ind));
auto cinx = mm->add_instruction(migraphx::make_op("contiguous"), inx); auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), inx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), cinx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r}); mm->add_return({r});
...@@ -3933,10 +3913,7 @@ TEST_CASE(scatter_test) ...@@ -3933,10 +3913,7 @@ TEST_CASE(scatter_test)
auto l2 = auto l2 =
mm->add_parameter("update", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); mm->add_parameter("update", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
int axis = -2; int axis = -2;
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0); auto r = mm->add_instruction(migraphx::make_op("scatter", {{"axis", axis}}), l0, l1, l2);
auto cl1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
auto cl2 = mm->add_instruction(migraphx::make_op("contiguous"), l2);
auto r = mm->add_instruction(migraphx::make_op("scatter", {{"axis", axis}}), cl0, cl1, cl2);
mm->add_return({r}); mm->add_return({r});
auto prog = migraphx::parse_onnx("scatter_test.onnx"); auto prog = migraphx::parse_onnx("scatter_test.onnx");
...@@ -3999,9 +3976,7 @@ TEST_CASE(shape_gather_test) ...@@ -3999,9 +3976,7 @@ TEST_CASE(shape_gather_test)
auto l1 = auto l1 =
mm->add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens()); mm->add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens());
int axis = 0; int axis = 0;
auto cl1 = mm->add_instruction(migraphx::make_op("contiguous"), l1); mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l1, l2);
auto cl2 = mm->add_instruction(migraphx::make_op("contiguous"), l2);
mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), cl1, cl2);
auto prog = optimize_onnx("shape_gather_test.onnx"); auto prog = optimize_onnx("shape_gather_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -4322,10 +4297,8 @@ TEST_CASE(squeeze_unsqueeze_test) ...@@ -4322,10 +4297,8 @@ TEST_CASE(squeeze_unsqueeze_test)
std::vector<int64_t> unsqueeze_axes{0, 1, 3, 5}; std::vector<int64_t> unsqueeze_axes{0, 1, 3, 5};
auto l0 = auto l0 =
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 1, 1, 2, 1}}); mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 1, 1, 2, 1}});
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0); auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", squeeze_axes}}), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", squeeze_axes}}), cl0); mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), l1);
auto cl1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), cl1);
auto prog = optimize_onnx("squeeze_unsqueeze_test.onnx"); auto prog = optimize_onnx("squeeze_unsqueeze_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -4336,9 +4309,8 @@ TEST_CASE(squeeze_axes_input_test) ...@@ -4336,9 +4309,8 @@ TEST_CASE(squeeze_axes_input_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal({migraphx::shape::int64_type, {2}}, {1, 3})); mm->add_literal(migraphx::literal({migraphx::shape::int64_type, {2}}, {1, 3}));
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1, 5, 1}}); auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1, 5, 1}});
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0); auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 3}}}), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 3}}}), cl0);
mm->add_return({l1}); mm->add_return({l1});
auto prog = migraphx::parse_onnx("squeeze_axes_input_test.onnx"); auto prog = migraphx::parse_onnx("squeeze_axes_input_test.onnx");
...@@ -4351,9 +4323,8 @@ TEST_CASE(squeeze_empty_axes_test) ...@@ -4351,9 +4323,8 @@ TEST_CASE(squeeze_empty_axes_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
mm->add_literal({}); mm->add_literal({});
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1, 5, 1}}); auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1, 5, 1}});
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0); auto l1 = mm->add_instruction(migraphx::make_op("squeeze"), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze"), cl0);
mm->add_return({l1}); mm->add_return({l1});
auto prog = migraphx::parse_onnx("squeeze_empty_axes_test.onnx"); auto prog = migraphx::parse_onnx("squeeze_empty_axes_test.onnx");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment