Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c79661d6
Commit
c79661d6
authored
Jan 27, 2022
by
Shucai Xiao
Browse files
fix review comments
parent
e35ad617
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
51 additions
and
67 deletions
+51
-67
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+13
-1
src/include/migraphx/op/reduce_op.hpp
src/include/migraphx/op/reduce_op.hpp
+1
-1
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+1
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+36
-65
No files found.
src/auto_contiguous.cpp
View file @
c79661d6
...
@@ -20,7 +20,7 @@ void auto_contiguous::apply(module& p) const
...
@@ -20,7 +20,7 @@ void auto_contiguous::apply(module& p) const
auto
args
=
ins
->
inputs
();
auto
args
=
ins
->
inputs
();
auto
new_args
=
args
;
auto
new_args
=
args
;
std
::
transform
(
args
.
begin
(),
args
.
end
(),
new_args
.
begin
(),
[
&
](
auto
in
)
{
std
::
transform
(
args
.
begin
(),
args
.
end
(),
new_args
.
begin
(),
[
&
](
auto
in
)
{
return
p
.
replace
_instruction
(
ins
,
make_op
(
"contiguous"
),
in
);
return
p
.
insert
_instruction
(
ins
,
make_op
(
"contiguous"
),
in
);
});
});
if
(
new_args
!=
args
)
if
(
new_args
!=
args
)
...
@@ -29,6 +29,18 @@ void auto_contiguous::apply(module& p) const
...
@@ -29,6 +29,18 @@ void auto_contiguous::apply(module& p) const
}
}
}
}
}
}
auto
last
=
std
::
prev
(
p
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
ins
->
outputs
().
empty
()
and
ins
!=
last
)
continue
;
shape
s
=
ins
->
get_shape
();
if
(
not
s
.
standard
()
and
s
.
elements
()
!=
0
)
{
auto
c
=
p
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"contiguous"
),
ins
);
p
.
replace_instruction
(
ins
,
c
);
}
}
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/op/reduce_op.hpp
View file @
c79661d6
...
@@ -66,7 +66,7 @@ struct reduce_op : op_name<Derived>
...
@@ -66,7 +66,7 @@ struct reduce_op : op_name<Derived>
{
{
value
normalize
;
value
normalize
;
normalize
[
"axes"
]
=
value
::
array
{
normalize_attribute
::
include_min
};
normalize
[
"axes"
]
=
value
::
array
{
normalize_attribute
::
include_min
};
return
{{
"normalize_axes"
,
normalize
}};
return
{{
"normalize_axes"
,
normalize
}
,
{
"standard_input_shape"
,
true
}
};
}
}
std
::
vector
<
int64_t
>
tune_axes
(
std
::
size_t
n_dim
)
const
std
::
vector
<
int64_t
>
tune_axes
(
std
::
size_t
n_dim
)
const
...
...
src/simplify_reshapes.cpp
View file @
c79661d6
...
@@ -26,6 +26,7 @@ const auto& reshaper_names()
...
@@ -26,6 +26,7 @@ const auto& reshaper_names()
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
"flatten"
,
"flatten"
,
"reshape"
,
"reshape"
,
"contiguous"
,
"squeeze"
,
"squeeze"
,
"unsqueeze"
"unsqueeze"
};
};
...
...
test/onnx/onnx_test.cpp
View file @
c79661d6
...
@@ -1264,10 +1264,8 @@ TEST_CASE(flatten_test)
...
@@ -1264,10 +1264,8 @@ TEST_CASE(flatten_test)
migraphx::program p;
migraphx::program p;
auto* mm = p.get_main_module();
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(migraphx::make_op("flatten", {{"axis", 2}}), l0);
mm->add_instruction(migraphx::make_op("flatten", {{"axis", 2}}), cl0);
mm->add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), l0);
auto cl1 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), cl1);
auto prog = optimize_onnx("flatten_test.onnx");
auto prog = optimize_onnx("flatten_test.onnx");
EXPECT(p == prog);
EXPECT(p == prog);
...
@@ -1308,9 +1306,7 @@ TEST_CASE(gather_test)
...
@@ -1308,9 +1306,7 @@ TEST_CASE(gather_test)
auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}});
auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}});
int axis = 1;
int axis = 1;
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l0, l1);
auto cl1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), cl0, cl1);
auto prog = optimize_onnx("gather_test.onnx");
auto prog = optimize_onnx("gather_test.onnx");
EXPECT(p == prog);
EXPECT(p == prog);
...
@@ -1330,13 +1326,11 @@ TEST_CASE(gather_elements_axis0_test)
...
@@ -1330,13 +1326,11 @@ TEST_CASE(gather_elements_axis0_test)
auto l_ind_axis_indices =
auto l_ind_axis_indices =
mm->add_literal(migraphx::literal{ind_s, ind_axis_indices.begin(), ind_axis_indices.end()});
mm->add_literal(migraphx::literal{ind_s, ind_axis_indices.begin(), ind_axis_indices.end()});
auto l_stride = mm->add_literal(migraphx::literal{{migraphx::shape::int32_type, {1}}, {4}});
auto l_stride = mm->add_literal(migraphx::literal{{migraphx::shape::int32_type, {1}}, {4}});
auto cdata = mm->add_instruction(migraphx::make_op("contiguous"), data);
auto cindices = mm->add_instruction(migraphx::make_op("contiguous"), indices);
auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}),
c
data);
auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data);
auto lbst_stride = mm->add_instruction(
auto lbst_stride = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
migraphx::make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
auto axis_delta = mm->add_instruction(migraphx::make_op("sub"),
c
indices, l_ind_axis_indices);
auto axis_delta = mm->add_instruction(migraphx::make_op("sub"), indices, l_ind_axis_indices);
auto mul_delta = mm->add_instruction(migraphx::make_op("mul"), axis_delta, lbst_stride);
auto mul_delta = mm->add_instruction(migraphx::make_op("mul"), axis_delta, lbst_stride);
auto ind = mm->add_instruction(migraphx::make_op("add"), l_data_indices, mul_delta);
auto ind = mm->add_instruction(migraphx::make_op("add"), l_data_indices, mul_delta);
auto ret = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp_data, ind);
auto ret = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp_data, ind);
...
@@ -1360,13 +1354,12 @@ TEST_CASE(gather_elements_axis1_test)
...
@@ -1360,13 +1354,12 @@ TEST_CASE(gather_elements_axis1_test)
mm->add_literal(migraphx::literal{ind_s, ind_indices.begin(), ind_indices.end()});
mm->add_literal(migraphx::literal{ind_s, ind_indices.begin(), ind_indices.end()});
auto l_ind_axis_indices =
auto l_ind_axis_indices =
mm->add_literal(migraphx::literal{ind_s, ind_axis_indices.begin(), ind_axis_indices.end()});
mm->add_literal(migraphx::literal{ind_s, ind_axis_indices.begin(), ind_axis_indices.end()});
auto l_stride = mm->add_literal(migraphx::literal{{migraphx::shape::int32_type, {1}}, {1}});
auto l_stride = mm->add_literal(migraphx::literal{{migraphx::shape::int32_type, {1}}, {1}});
auto cdata = mm->add_instruction(migraphx::make_op("contiguous"), data);
auto cindices = mm->add_instruction(migraphx::make_op("contiguous"), indices);
auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data);
auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), cdata);
auto lbst_stride = mm->add_instruction(
auto lbst_stride = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
migraphx::make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
auto axis_delta = mm->add_instruction(migraphx::make_op("sub"),
c
indices, l_ind_axis_indices);
auto axis_delta = mm->add_instruction(migraphx::make_op("sub"), indices, l_ind_axis_indices);
auto mul_delta = mm->add_instruction(migraphx::make_op("mul"), axis_delta, lbst_stride);
auto mul_delta = mm->add_instruction(migraphx::make_op("mul"), axis_delta, lbst_stride);
auto ind = mm->add_instruction(migraphx::make_op("add"), l_data_indices, mul_delta);
auto ind = mm->add_instruction(migraphx::make_op("add"), l_data_indices, mul_delta);
auto ret = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp_data, ind);
auto ret = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp_data, ind);
...
@@ -2630,14 +2623,8 @@ TEST_CASE(nms_test)
...
@@ -2630,14 +2623,8 @@ TEST_CASE(nms_test)
migraphx::shape sst{migraphx::shape::float_type, {1}};
migraphx::shape sst{migraphx::shape::float_type, {1}};
auto st = mm->add_parameter("score_threshold", sst);
auto st = mm->add_parameter("score_threshold", sst);
auto cb = mm->add_instruction(migraphx::make_op("contiguous"), b);
auto cs = mm->add_instruction(migraphx::make_op("contiguous"), s);
auto cmo = mm->add_instruction(migraphx::make_op("contiguous"), mo);
auto ciou = mm->add_instruction(migraphx::make_op("contiguous"), iou);
auto cst = mm->add_instruction(migraphx::make_op("contiguous"), st);
auto ret = mm->add_instruction(
auto ret = mm->add_instruction(
migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}),
c
b,
c
s,
c
mo,
c
iou,
c
st);
migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}), b, s, mo, iou, st);
mm->add_return({ret});
mm->add_return({ret});
auto prog = migraphx::parse_onnx("nms_test.onnx");
auto prog = migraphx::parse_onnx("nms_test.onnx");
...
@@ -3408,12 +3395,10 @@ TEST_CASE(reshape_test)
...
@@ -3408,12 +3395,10 @@ TEST_CASE(reshape_test)
std::vector<int64_t> reshape_dims{3, 8};
std::vector<int64_t> reshape_dims{3, 8};
mm->add_literal(
mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims});
migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
op.dims = reshape_dims;
op.dims = reshape_dims;
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(op, l0);
mm->add_instruction(op, cl0);
mm->add_instruction(op, l0);
auto cl1 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(op, cl1);
auto prog = optimize_onnx("reshape_test.onnx");
auto prog = optimize_onnx("reshape_test.onnx");
EXPECT(p == prog);
EXPECT(p == prog);
...
@@ -3453,9 +3438,9 @@ TEST_CASE(resize_downsample_c_test)
...
@@ -3453,9 +3438,9 @@ TEST_CASE(resize_downsample_c_test)
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 1, 2}};
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 1, 2}};
std::vector<int> ind = {0, 2};
std::vector<int> ind = {0, 2};
auto li = mm->add_literal(migraphx::literal(si, ind));
auto li = mm->add_literal(migraphx::literal(si, ind));
auto cinx = mm->add_instruction(migraphx::make_op("contiguous"), inx);
auto lrsp
= mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}),
c
inx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), inx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
auto r
= mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r});
mm->add_return({r});
auto prog = migraphx::parse_onnx("resize_downsample_c_test.onnx");
auto prog = migraphx::parse_onnx("resize_downsample_c_test.onnx");
...
@@ -3479,9 +3464,9 @@ TEST_CASE(resize_downsample_f_test)
...
@@ -3479,9 +3464,9 @@ TEST_CASE(resize_downsample_f_test)
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 1, 2}};
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 1, 2}};
std::vector<int> ind = {0, 3};
std::vector<int> ind = {0, 3};
auto li = mm->add_literal(migraphx::literal(si, ind));
auto li = mm->add_literal(migraphx::literal(si, ind));
auto cinx = mm->add_instruction(migraphx::make_op("contiguous"), inx);
auto lrsp
= mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}),
c
inx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), inx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
auto r
= mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r});
mm->add_return({r});
auto prog = migraphx::parse_onnx("resize_downsample_f_test.onnx");
auto prog = migraphx::parse_onnx("resize_downsample_f_test.onnx");
...
@@ -3521,8 +3506,7 @@ TEST_CASE(resize_downsample_linear_test)
...
@@ -3521,8 +3506,7 @@ TEST_CASE(resize_downsample_linear_test)
auto l1 = mm->add_literal(migraphx::literal(s1, d1));
auto l1 = mm->add_literal(migraphx::literal(s1, d1));
mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(migraphx::make_op("undefined"));
auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x);
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), x);
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), cx);
auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind);
auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind);
auto slc80 = mm->add_instruction(
auto slc80 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data);
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data);
...
@@ -3575,9 +3559,9 @@ TEST_CASE(resize_outsize_test)
...
@@ -3575,9 +3559,9 @@ TEST_CASE(resize_outsize_test)
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}};
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}};
std::vector<int> ind = {0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3};
std::vector<int> ind = {0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3};
auto li = mm->add_literal(migraphx::literal(si, ind));
auto li = mm->add_literal(migraphx::literal(si, ind));
auto cinx = mm->add_instruction(migraphx::make_op("contiguous"), inx);
auto lrsp
= mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}),
c
inx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), inx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
auto r
= mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r});
mm->add_return({r});
auto prog = migraphx::parse_onnx("resize_outsize_test.onnx");
auto prog = migraphx::parse_onnx("resize_outsize_test.onnx");
...
@@ -3674,8 +3658,7 @@ TEST_CASE(resize_upsample_linear_ac_test)
...
@@ -3674,8 +3658,7 @@ TEST_CASE(resize_upsample_linear_ac_test)
auto l1 = mm->add_literal(migraphx::literal(s1, d1));
auto l1 = mm->add_literal(migraphx::literal(s1, d1));
mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(migraphx::make_op("undefined"));
auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x);
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), x);
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), cx);
auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind);
auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind);
auto slc80 = mm->add_instruction(
auto slc80 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data);
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data);
...
@@ -3770,8 +3753,7 @@ TEST_CASE(resize_upsample_linear_test)
...
@@ -3770,8 +3753,7 @@ TEST_CASE(resize_upsample_linear_test)
auto l1 = mm->add_literal(migraphx::literal(s1, d1));
auto l1 = mm->add_literal(migraphx::literal(s1, d1));
mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(migraphx::make_op("undefined"));
auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x);
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), x);
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), cx);
auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind);
auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind);
auto slc80 = mm->add_instruction(
auto slc80 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data);
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data);
...
@@ -3825,8 +3807,7 @@ TEST_CASE(resize_upsample_pc_test)
...
@@ -3825,8 +3807,7 @@ TEST_CASE(resize_upsample_pc_test)
std::vector<int> ind = {0, 1, 1, 2, 3, 3, 0, 1, 1, 2, 3, 3, 4, 5, 5, 6, 7, 7, 4, 5, 5, 6, 7, 7};
std::vector<int> ind = {0, 1, 1, 2, 3, 3, 0, 1, 1, 2, 3, 3, 4, 5, 5, 6, 7, 7, 4, 5, 5, 6, 7, 7};
auto li = mm->add_literal(migraphx::literal(si, ind));
auto li = mm->add_literal(migraphx::literal(si, ind));
auto cinx = mm->add_instruction(migraphx::make_op("contiguous"), inx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), inx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), cinx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r});
mm->add_return({r});
...
@@ -3853,8 +3834,7 @@ TEST_CASE(resize_upsample_pf_test)
...
@@ -3853,8 +3834,7 @@ TEST_CASE(resize_upsample_pf_test)
std::vector<int> ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3};
std::vector<int> ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3};
auto li = mm->add_literal(migraphx::literal(si, ind));
auto li = mm->add_literal(migraphx::literal(si, ind));
auto cinx = mm->add_instruction(migraphx::make_op("contiguous"), inx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), inx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), cinx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r});
mm->add_return({r});
...
@@ -3933,10 +3913,7 @@ TEST_CASE(scatter_test)
...
@@ -3933,10 +3913,7 @@ TEST_CASE(scatter_test)
auto l2 =
auto l2 =
mm->add_parameter("update", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
mm->add_parameter("update", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
int axis = -2;
int axis = -2;
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto r = mm->add_instruction(migraphx::make_op("scatter", {{"axis", axis}}), l0, l1, l2);
auto cl1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
auto cl2 = mm->add_instruction(migraphx::make_op("contiguous"), l2);
auto r = mm->add_instruction(migraphx::make_op("scatter", {{"axis", axis}}), cl0, cl1, cl2);
mm->add_return({r});
mm->add_return({r});
auto prog = migraphx::parse_onnx("scatter_test.onnx");
auto prog = migraphx::parse_onnx("scatter_test.onnx");
...
@@ -3999,9 +3976,7 @@ TEST_CASE(shape_gather_test)
...
@@ -3999,9 +3976,7 @@ TEST_CASE(shape_gather_test)
auto l1 =
auto l1 =
mm->add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens());
mm->add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens());
int axis = 0;
int axis = 0;
auto cl1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l1, l2);
auto cl2 = mm->add_instruction(migraphx::make_op("contiguous"), l2);
mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), cl1, cl2);
auto prog = optimize_onnx("shape_gather_test.onnx");
auto prog = optimize_onnx("shape_gather_test.onnx");
EXPECT(p == prog);
EXPECT(p == prog);
...
@@ -4322,10 +4297,8 @@ TEST_CASE(squeeze_unsqueeze_test)
...
@@ -4322,10 +4297,8 @@ TEST_CASE(squeeze_unsqueeze_test)
std::vector<int64_t> unsqueeze_axes{0, 1, 3, 5};
std::vector<int64_t> unsqueeze_axes{0, 1, 3, 5};
auto l0 =
auto l0 =
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 1, 1, 2, 1}});
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 1, 1, 2, 1}});
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", squeeze_axes}}), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", squeeze_axes}}), cl0);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), l1);
auto cl1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), cl1);
auto prog = optimize_onnx("squeeze_unsqueeze_test.onnx");
auto prog = optimize_onnx("squeeze_unsqueeze_test.onnx");
EXPECT(p == prog);
EXPECT(p == prog);
...
@@ -4336,9 +4309,8 @@ TEST_CASE(squeeze_axes_input_test)
...
@@ -4336,9 +4309,8 @@ TEST_CASE(squeeze_axes_input_test)
migraphx::program p;
migraphx::program p;
auto* mm = p.get_main_module();
auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal({migraphx::shape::int64_type, {2}}, {1, 3}));
mm->add_literal(migraphx::literal({migraphx::shape::int64_type, {2}}, {1, 3}));
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1, 5, 1}});
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1, 5, 1}});
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 3}}}), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 3}}}), cl0);
mm->add_return({l1});
mm->add_return({l1});
auto prog = migraphx::parse_onnx("squeeze_axes_input_test.onnx");
auto prog = migraphx::parse_onnx("squeeze_axes_input_test.onnx");
...
@@ -4351,9 +4323,8 @@ TEST_CASE(squeeze_empty_axes_test)
...
@@ -4351,9 +4323,8 @@ TEST_CASE(squeeze_empty_axes_test)
migraphx::program p;
migraphx::program p;
auto* mm = p.get_main_module();
auto* mm = p.get_main_module();
mm->add_literal({});
mm->add_literal({});
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1, 5, 1}});
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1, 5, 1}});
auto cl0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze"), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze"), cl0);
mm->add_return({l1});
mm->add_return({l1});
auto prog = migraphx::parse_onnx("squeeze_empty_axes_test.onnx");
auto prog = migraphx::parse_onnx("squeeze_empty_axes_test.onnx");
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment