Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
86331dbd
Commit
86331dbd
authored
Jan 24, 2022
by
Shucai Xiao
Browse files
refine onnx unit tests
parent
6195c942
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
78 additions
and
29 deletions
+78
-29
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+14
-0
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+5
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+57
-28
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+2
-1
No files found.
src/eliminate_contiguous.cpp
View file @
86331dbd
...
...
@@ -75,7 +75,21 @@ void eliminate_contiguous::apply(module& p) const
{
// return instruction should have inputs with standard shape
if
(
ins
->
name
()
==
"@return"
)
{
auto
args
=
ins
->
inputs
();
std
::
transform
(
args
.
begin
(),
args
.
end
(),
args
.
begin
(),
[
&
](
auto
in
)
{
if
(
in
->
name
()
!=
op_name
)
return
in
;
auto
prev
=
in
->
inputs
().
front
();
return
prev
->
get_shape
().
standard
()
?
prev
:
in
;
});
if
(
args
!=
ins
->
inputs
())
{
p
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
args
);
}
continue
;
}
// Make a copy so we can modify it while we iterate
auto
args
=
ins
->
inputs
();
...
...
src/onnx/onnx_parser.cpp
View file @
86331dbd
...
...
@@ -70,6 +70,11 @@ static literal from_repeated(shape::type_t t, const T& r)
instruction_ref
onnx_parser
::
node_info
::
make_contiguous
(
instruction_ref
ins
)
const
{
if
(
ins
->
name
()
==
"contiguous"
)
{
return
ins
;
}
return
add_instruction
(
make_op
(
"contiguous"
),
ins
);
}
...
...
test/onnx/onnx_test.cpp
View file @
86331dbd
...
...
@@ -1264,8 +1264,10 @@ TEST_CASE(flatten_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
mm->add_instruction(migraphx::make_op("flatten", {{"axis", 2}}), l0);
mm->add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), l0);
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(migraphx::make_op("flatten", {{"axis", 2}}), cl0);
auto cl1 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), cl1);
auto prog = optimize_onnx("flatten_test.onnx");
EXPECT(p == prog);
...
...
@@ -1306,7 +1308,9 @@ TEST_CASE(gather_test)
auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}});
int axis = 1;
mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l0, l1);
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto cl1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), cl0, cl1);
auto prog = optimize_onnx("gather_test.onnx");
EXPECT(p == prog);
...
...
@@ -1326,11 +1330,13 @@ TEST_CASE(gather_elements_axis0_test)
auto l_ind_axis_indices =
mm->add_literal(migraphx::literal{ind_s, ind_axis_indices.begin(), ind_axis_indices.end()});
auto l_stride = mm->add_literal(migraphx::literal{{migraphx::shape::int32_type, {1}}, {4}});
auto cdata = mm->add_instruction(migraphx::make_op("contiguous"), data);
auto cindices = mm->add_instruction(migraphx::make_op("contiguous"), indices);
auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data);
auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}),
c
data);
auto lbst_stride = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
auto axis_delta = mm->add_instruction(migraphx::make_op("sub"), indices, l_ind_axis_indices);
auto axis_delta = mm->add_instruction(migraphx::make_op("sub"),
c
indices, l_ind_axis_indices);
auto mul_delta = mm->add_instruction(migraphx::make_op("mul"), axis_delta, lbst_stride);
auto ind = mm->add_instruction(migraphx::make_op("add"), l_data_indices, mul_delta);
auto ret = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp_data, ind);
...
...
@@ -1355,11 +1361,12 @@ TEST_CASE(gather_elements_axis1_test)
auto l_ind_axis_indices =
mm->add_literal(migraphx::literal{ind_s, ind_axis_indices.begin(), ind_axis_indices.end()});
auto l_stride = mm->add_literal(migraphx::literal{{migraphx::shape::int32_type, {1}}, {1}});
auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data);
auto cdata = mm->add_instruction(migraphx::make_op("contiguous"), data);
auto cindices = mm->add_instruction(migraphx::make_op("contiguous"), indices);
auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), cdata);
auto lbst_stride = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
auto axis_delta = mm->add_instruction(migraphx::make_op("sub"), indices, l_ind_axis_indices);
auto axis_delta = mm->add_instruction(migraphx::make_op("sub"),
c
indices, l_ind_axis_indices);
auto mul_delta = mm->add_instruction(migraphx::make_op("mul"), axis_delta, lbst_stride);
auto ind = mm->add_instruction(migraphx::make_op("add"), l_data_indices, mul_delta);
auto ret = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp_data, ind);
...
...
@@ -2588,8 +2595,14 @@ TEST_CASE(nms_test)
migraphx::shape sst{migraphx::shape::float_type, {1}};
auto st = mm->add_parameter("score_threshold", sst);
auto cb = mm->add_instruction(migraphx::make_op("contiguous"), b);
auto cs = mm->add_instruction(migraphx::make_op("contiguous"), s);
auto cmo = mm->add_instruction(migraphx::make_op("contiguous"), mo);
auto ciou = mm->add_instruction(migraphx::make_op("contiguous"), iou);
auto cst = mm->add_instruction(migraphx::make_op("contiguous"), st);
auto ret = mm->add_instruction(
migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}), b, s, mo, iou, st);
migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}),
c
b,
c
s,
c
mo,
c
iou,
c
st);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("nms_test.onnx");
...
...
@@ -3362,8 +3375,10 @@ TEST_CASE(reshape_test)
migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
op.dims = reshape_dims;
mm->add_instruction(op, l0);
mm->add_instruction(op, l0);
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(op, cl0);
auto cl1 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(op, cl1);
auto prog = optimize_onnx("reshape_test.onnx");
EXPECT(p == prog);
...
...
@@ -3403,8 +3418,8 @@ TEST_CASE(resize_downsample_c_test)
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 1, 2}};
std::vector<int> ind = {0, 2};
auto li = mm->add_literal(migraphx::literal(si, ind));
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), inx);
auto cinx = mm->add_instruction(migraphx::make_op("contiguous"), inx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}),
c
inx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r});
...
...
@@ -3429,8 +3444,8 @@ TEST_CASE(resize_downsample_f_test)
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 1, 2}};
std::vector<int> ind = {0, 3};
auto li = mm->add_literal(migraphx::literal(si, ind));
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), inx);
auto cinx = mm->add_instruction(migraphx::make_op("contiguous"), inx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}),
c
inx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r});
...
...
@@ -3471,7 +3486,8 @@ TEST_CASE(resize_downsample_linear_test)
auto l1 = mm->add_literal(migraphx::literal(s1, d1));
mm->add_instruction(migraphx::make_op("undefined"));
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), x);
auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x);
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), cx);
auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind);
auto slc80 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data);
...
...
@@ -3524,8 +3540,8 @@ TEST_CASE(resize_outsize_test)
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}};
std::vector<int> ind = {0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3};
auto li = mm->add_literal(migraphx::literal(si, ind));
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), inx);
auto cinx = mm->add_instruction(migraphx::make_op("contiguous"), inx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}),
c
inx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r});
...
...
@@ -3623,7 +3639,8 @@ TEST_CASE(resize_upsample_linear_ac_test)
auto l1 = mm->add_literal(migraphx::literal(s1, d1));
mm->add_instruction(migraphx::make_op("undefined"));
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), x);
auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x);
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), cx);
auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind);
auto slc80 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data);
...
...
@@ -3718,7 +3735,8 @@ TEST_CASE(resize_upsample_linear_test)
auto l1 = mm->add_literal(migraphx::literal(s1, d1));
mm->add_instruction(migraphx::make_op("undefined"));
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), x);
auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x);
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), cx);
auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind);
auto slc80 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data);
...
...
@@ -3772,7 +3790,8 @@ TEST_CASE(resize_upsample_pc_test)
std::vector<int> ind = {0, 1, 1, 2, 3, 3, 0, 1, 1, 2, 3, 3, 4, 5, 5, 6, 7, 7, 4, 5, 5, 6, 7, 7};
auto li = mm->add_literal(migraphx::literal(si, ind));
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), inx);
auto cinx = mm->add_instruction(migraphx::make_op("contiguous"), inx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), cinx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r});
...
...
@@ -3799,7 +3818,8 @@ TEST_CASE(resize_upsample_pf_test)
std::vector<int> ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3};
auto li = mm->add_literal(migraphx::literal(si, ind));
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), inx);
auto cinx = mm->add_instruction(migraphx::make_op("contiguous"), inx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), cinx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r});
...
...
@@ -3878,7 +3898,10 @@ TEST_CASE(scatter_test)
auto l2 =
mm->add_parameter("update", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
int axis = -2;
auto r = mm->add_instruction(migraphx::make_op("scatter", {{"axis", axis}}), l0, l1, l2);
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto cl1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
auto cl2 = mm->add_instruction(migraphx::make_op("contiguous"), l2);
auto r = mm->add_instruction(migraphx::make_op("scatter", {{"axis", axis}}), cl0, cl1, cl2);
mm->add_return({r});
auto prog = migraphx::parse_onnx("scatter_test.onnx");
...
...
@@ -3941,7 +3964,9 @@ TEST_CASE(shape_gather_test)
auto l1 =
mm->add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens());
int axis = 0;
mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l1, l2);
auto cl1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
auto cl2 = mm->add_instruction(migraphx::make_op("contiguous"), l2);
mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), cl1, cl2);
auto prog = optimize_onnx("shape_gather_test.onnx");
EXPECT(p == prog);
...
...
@@ -4262,8 +4287,10 @@ TEST_CASE(squeeze_unsqueeze_test)
std::vector<int64_t> unsqueeze_axes{0, 1, 3, 5};
auto l0 =
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 1, 1, 2, 1}});
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", squeeze_axes}}), l0);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), l1);
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", squeeze_axes}}), cl0);
auto cl1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), cl1);
auto prog = optimize_onnx("squeeze_unsqueeze_test.onnx");
EXPECT(p == prog);
...
...
@@ -4275,7 +4302,8 @@ TEST_CASE(squeeze_axes_input_test)
auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal({migraphx::shape::int64_type, {2}}, {1, 3}));
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1, 5, 1}});
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 3}}}), l0);
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 3}}}), cl0);
mm->add_return({l1});
auto prog = migraphx::parse_onnx("squeeze_axes_input_test.onnx");
...
...
@@ -4289,7 +4317,8 @@ TEST_CASE(squeeze_empty_axes_test)
auto* mm = p.get_main_module();
mm->add_literal({});
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1, 5, 1}});
auto l1 = mm->add_instruction(migraphx::make_op("squeeze"), l0);
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze"), cl0);
mm->add_return({l1});
auto prog = migraphx::parse_onnx("squeeze_empty_axes_test.onnx");
...
...
test/simplify_reshapes_test.cpp
View file @
86331dbd
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
...
...
@@ -13,7 +14,7 @@
void
run_pass
(
migraphx
::
module
&
m
)
{
migraphx
::
run_passes
(
m
,
{
migraphx
::
simplify_reshapes
{},
migraphx
::
dead_code_elimination
{}});
migraphx
::
run_passes
(
m
,
{
migraphx
::
simplify_reshapes
{},
migraphx
::
eliminate_contiguous
{
"contiguous"
},
migraphx
::
dead_code_elimination
{}});
}
TEST_CASE
(
double_contig
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment