Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c79661d6
Commit
c79661d6
authored
Jan 27, 2022
by
Shucai Xiao
Browse files
fix review comments
parent
e35ad617
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
51 additions
and
67 deletions
+51
-67
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+13
-1
src/include/migraphx/op/reduce_op.hpp
src/include/migraphx/op/reduce_op.hpp
+1
-1
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+1
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+36
-65
No files found.
src/auto_contiguous.cpp
View file @
c79661d6
...
@@ -20,7 +20,7 @@ void auto_contiguous::apply(module& p) const
...
@@ -20,7 +20,7 @@ void auto_contiguous::apply(module& p) const
auto
args
=
ins
->
inputs
();
auto
args
=
ins
->
inputs
();
auto
new_args
=
args
;
auto
new_args
=
args
;
std
::
transform
(
args
.
begin
(),
args
.
end
(),
new_args
.
begin
(),
[
&
](
auto
in
)
{
std
::
transform
(
args
.
begin
(),
args
.
end
(),
new_args
.
begin
(),
[
&
](
auto
in
)
{
return
p
.
replace
_instruction
(
ins
,
make_op
(
"contiguous"
),
in
);
return
p
.
insert
_instruction
(
ins
,
make_op
(
"contiguous"
),
in
);
});
});
if
(
new_args
!=
args
)
if
(
new_args
!=
args
)
...
@@ -29,6 +29,18 @@ void auto_contiguous::apply(module& p) const
...
@@ -29,6 +29,18 @@ void auto_contiguous::apply(module& p) const
}
}
}
}
}
}
auto
last
=
std
::
prev
(
p
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
ins
->
outputs
().
empty
()
and
ins
!=
last
)
continue
;
shape
s
=
ins
->
get_shape
();
if
(
not
s
.
standard
()
and
s
.
elements
()
!=
0
)
{
auto
c
=
p
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"contiguous"
),
ins
);
p
.
replace_instruction
(
ins
,
c
);
}
}
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/op/reduce_op.hpp
View file @
c79661d6
...
@@ -66,7 +66,7 @@ struct reduce_op : op_name<Derived>
...
@@ -66,7 +66,7 @@ struct reduce_op : op_name<Derived>
{
{
value
normalize
;
value
normalize
;
normalize
[
"axes"
]
=
value
::
array
{
normalize_attribute
::
include_min
};
normalize
[
"axes"
]
=
value
::
array
{
normalize_attribute
::
include_min
};
return
{{
"normalize_axes"
,
normalize
}};
return
{{
"normalize_axes"
,
normalize
}
,
{
"standard_input_shape"
,
true
}
};
}
}
std
::
vector
<
int64_t
>
tune_axes
(
std
::
size_t
n_dim
)
const
std
::
vector
<
int64_t
>
tune_axes
(
std
::
size_t
n_dim
)
const
...
...
src/simplify_reshapes.cpp
View file @
c79661d6
...
@@ -26,6 +26,7 @@ const auto& reshaper_names()
...
@@ -26,6 +26,7 @@ const auto& reshaper_names()
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
"flatten"
,
"flatten"
,
"reshape"
,
"reshape"
,
"contiguous"
,
"squeeze"
,
"squeeze"
,
"unsqueeze"
"unsqueeze"
};
};
...
...
test/onnx/onnx_test.cpp
View file @
c79661d6
...
@@ -1264,10 +1264,8 @@ TEST_CASE(flatten_test)
...
@@ -1264,10 +1264,8 @@ TEST_CASE(flatten_test)
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
l0
=
mm
->
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}});
auto
l0
=
mm
->
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}});
auto
cl0
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l0
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"flatten"
,
{{
"axis"
,
2
}}),
l0
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"flatten"
,
{{
"axis"
,
2
}}),
cl0
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"flatten"
,
{{
"axis"
,
1
}}),
l0
);
auto
cl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l0
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"flatten"
,
{{
"axis"
,
1
}}),
cl1
);
auto
prog
=
optimize_onnx
(
"flatten_test.onnx"
);
auto
prog
=
optimize_onnx
(
"flatten_test.onnx"
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
...
@@ -1308,9 +1306,7 @@ TEST_CASE(gather_test)
...
@@ -1308,9 +1306,7 @@ TEST_CASE(gather_test)
auto
l0
=
mm
->
add_parameter
(
"data"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
,
6
}});
auto
l0
=
mm
->
add_parameter
(
"data"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
,
6
}});
auto
l1
=
mm
->
add_parameter
(
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}});
auto
l1
=
mm
->
add_parameter
(
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}});
int
axis
=
1
;
int
axis
=
1
;
auto
cl0
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l0
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
axis
}}),
l0
,
l1
);
auto
cl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
axis
}}),
cl0
,
cl1
);
auto
prog
=
optimize_onnx
(
"gather_test.onnx"
);
auto
prog
=
optimize_onnx
(
"gather_test.onnx"
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
...
@@ -1330,13 +1326,11 @@ TEST_CASE(gather_elements_axis0_test)
...
@@ -1330,13 +1326,11 @@ TEST_CASE(gather_elements_axis0_test)
auto
l_ind_axis_indices
=
auto
l_ind_axis_indices
=
mm
->
add_literal
(
migraphx
::
literal
{
ind_s
,
ind_axis_indices
.
begin
(),
ind_axis_indices
.
end
()});
mm
->
add_literal
(
migraphx
::
literal
{
ind_s
,
ind_axis_indices
.
begin
(),
ind_axis_indices
.
end
()});
auto
l_stride
=
mm
->
add_literal
(
migraphx
::
literal
{{
migraphx
::
shape
::
int32_type
,
{
1
}},
{
4
}});
auto
l_stride
=
mm
->
add_literal
(
migraphx
::
literal
{{
migraphx
::
shape
::
int32_type
,
{
1
}},
{
4
}});
auto
cdata
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
data
);
auto
cindices
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
indices
);
auto
rsp_data
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
12
}}}),
c
data
);
auto
rsp_data
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
12
}}}),
data
);
auto
lbst_stride
=
mm
->
add_instruction
(
auto
lbst_stride
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
ind_s
.
lens
()}}),
l_stride
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
ind_s
.
lens
()}}),
l_stride
);
auto
axis_delta
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"sub"
),
c
indices
,
l_ind_axis_indices
);
auto
axis_delta
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"sub"
),
indices
,
l_ind_axis_indices
);
auto
mul_delta
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
axis_delta
,
lbst_stride
);
auto
mul_delta
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
axis_delta
,
lbst_stride
);
auto
ind
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
l_data_indices
,
mul_delta
);
auto
ind
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
l_data_indices
,
mul_delta
);
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
rsp_data
,
ind
);
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
rsp_data
,
ind
);
...
@@ -1361,12 +1355,11 @@ TEST_CASE(gather_elements_axis1_test)
...
@@ -1361,12 +1355,11 @@ TEST_CASE(gather_elements_axis1_test)
auto
l_ind_axis_indices
=
auto
l_ind_axis_indices
=
mm
->
add_literal
(
migraphx
::
literal
{
ind_s
,
ind_axis_indices
.
begin
(),
ind_axis_indices
.
end
()});
mm
->
add_literal
(
migraphx
::
literal
{
ind_s
,
ind_axis_indices
.
begin
(),
ind_axis_indices
.
end
()});
auto
l_stride
=
mm
->
add_literal
(
migraphx
::
literal
{{
migraphx
::
shape
::
int32_type
,
{
1
}},
{
1
}});
auto
l_stride
=
mm
->
add_literal
(
migraphx
::
literal
{{
migraphx
::
shape
::
int32_type
,
{
1
}},
{
1
}});
auto
cdata
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
data
);
auto
cindices
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
indices
);
auto
rsp_data
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
12
}}}),
data
);
auto
rsp_data
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
12
}}}),
cdata
);
auto
lbst_stride
=
mm
->
add_instruction
(
auto
lbst_stride
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
ind_s
.
lens
()}}),
l_stride
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
ind_s
.
lens
()}}),
l_stride
);
auto
axis_delta
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"sub"
),
c
indices
,
l_ind_axis_indices
);
auto
axis_delta
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"sub"
),
indices
,
l_ind_axis_indices
);
auto
mul_delta
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
axis_delta
,
lbst_stride
);
auto
mul_delta
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
axis_delta
,
lbst_stride
);
auto
ind
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
l_data_indices
,
mul_delta
);
auto
ind
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
l_data_indices
,
mul_delta
);
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
rsp_data
,
ind
);
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
rsp_data
,
ind
);
...
@@ -2630,14 +2623,8 @@ TEST_CASE(nms_test)
...
@@ -2630,14 +2623,8 @@ TEST_CASE(nms_test)
migraphx
::
shape
sst
{
migraphx
::
shape
::
float_type
,
{
1
}};
migraphx
::
shape
sst
{
migraphx
::
shape
::
float_type
,
{
1
}};
auto
st
=
mm
->
add_parameter
(
"score_threshold"
,
sst
);
auto
st
=
mm
->
add_parameter
(
"score_threshold"
,
sst
);
auto
cb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
b
);
auto
cs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
s
);
auto
cmo
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
mo
);
auto
ciou
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
iou
);
auto
cst
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
st
);
auto
ret
=
mm
->
add_instruction
(
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"nonmaxsuppression"
,
{{
"center_point_box"
,
1
}}),
c
b
,
c
s
,
c
mo
,
c
iou
,
c
st
);
migraphx
::
make_op
(
"nonmaxsuppression"
,
{{
"center_point_box"
,
1
}}),
b
,
s
,
mo
,
iou
,
st
);
mm
->
add_return
({
ret
});
mm
->
add_return
({
ret
});
auto
prog
=
migraphx
::
parse_onnx
(
"nms_test.onnx"
);
auto
prog
=
migraphx
::
parse_onnx
(
"nms_test.onnx"
);
...
@@ -3410,10 +3397,8 @@ TEST_CASE(reshape_test)
...
@@ -3410,10 +3397,8 @@ TEST_CASE(reshape_test)
migraphx
::
literal
{
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
}},
reshape_dims
});
migraphx
::
literal
{
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
}},
reshape_dims
});
auto
l0
=
mm
->
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
2
,
3
}});
auto
l0
=
mm
->
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
2
,
3
}});
op
.
dims
=
reshape_dims
;
op
.
dims
=
reshape_dims
;
auto
cl0
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l0
);
mm
->
add_instruction
(
op
,
l0
);
mm
->
add_instruction
(
op
,
cl0
);
mm
->
add_instruction
(
op
,
l0
);
auto
cl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l0
);
mm
->
add_instruction
(
op
,
cl1
);
auto
prog
=
optimize_onnx
(
"reshape_test.onnx"
);
auto
prog
=
optimize_onnx
(
"reshape_test.onnx"
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
...
@@ -3453,8 +3438,8 @@ TEST_CASE(resize_downsample_c_test)
...
@@ -3453,8 +3438,8 @@ TEST_CASE(resize_downsample_c_test)
migraphx
::
shape
si
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
,
1
,
2
}};
migraphx
::
shape
si
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
,
1
,
2
}};
std
::
vector
<
int
>
ind
=
{
0
,
2
};
std
::
vector
<
int
>
ind
=
{
0
,
2
};
auto
li
=
mm
->
add_literal
(
migraphx
::
literal
(
si
,
ind
));
auto
li
=
mm
->
add_literal
(
migraphx
::
literal
(
si
,
ind
));
auto
cinx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
inx
);
auto
lrsp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
8
}}}),
c
inx
);
auto
lrsp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
8
}}}),
inx
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
lrsp
,
li
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
lrsp
,
li
);
mm
->
add_return
({
r
});
mm
->
add_return
({
r
});
...
@@ -3479,8 +3464,8 @@ TEST_CASE(resize_downsample_f_test)
...
@@ -3479,8 +3464,8 @@ TEST_CASE(resize_downsample_f_test)
migraphx
::
shape
si
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
,
1
,
2
}};
migraphx
::
shape
si
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
,
1
,
2
}};
std
::
vector
<
int
>
ind
=
{
0
,
3
};
std
::
vector
<
int
>
ind
=
{
0
,
3
};
auto
li
=
mm
->
add_literal
(
migraphx
::
literal
(
si
,
ind
));
auto
li
=
mm
->
add_literal
(
migraphx
::
literal
(
si
,
ind
));
auto
cinx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
inx
);
auto
lrsp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
8
}}}),
c
inx
);
auto
lrsp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
8
}}}),
inx
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
lrsp
,
li
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
lrsp
,
li
);
mm
->
add_return
({
r
});
mm
->
add_return
({
r
});
...
@@ -3521,8 +3506,7 @@ TEST_CASE(resize_downsample_linear_test)
...
@@ -3521,8 +3506,7 @@ TEST_CASE(resize_downsample_linear_test)
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
(
s1
,
d1
));
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
(
s1
,
d1
));
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
auto
cx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
x
);
auto
rsp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
8
}}}),
x
);
auto
rsp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
8
}}}),
cx
);
auto
data
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
rsp
,
l_ind
);
auto
data
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
rsp
,
l_ind
);
auto
slc80
=
mm
->
add_instruction
(
auto
slc80
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
8
}}}),
data
);
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
8
}}}),
data
);
...
@@ -3575,8 +3559,8 @@ TEST_CASE(resize_outsize_test)
...
@@ -3575,8 +3559,8 @@ TEST_CASE(resize_outsize_test)
migraphx
::
shape
si
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
,
4
,
6
}};
migraphx
::
shape
si
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
,
4
,
6
}};
std
::
vector
<
int
>
ind
=
{
0
,
0
,
1
,
1
,
1
,
1
,
2
,
2
,
3
,
3
,
3
,
3
,
2
,
2
,
3
,
3
,
3
,
3
,
2
,
2
,
3
,
3
,
3
,
3
};
std
::
vector
<
int
>
ind
=
{
0
,
0
,
1
,
1
,
1
,
1
,
2
,
2
,
3
,
3
,
3
,
3
,
2
,
2
,
3
,
3
,
3
,
3
,
2
,
2
,
3
,
3
,
3
,
3
};
auto
li
=
mm
->
add_literal
(
migraphx
::
literal
(
si
,
ind
));
auto
li
=
mm
->
add_literal
(
migraphx
::
literal
(
si
,
ind
));
auto
cinx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
inx
);
auto
lrsp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
4
}}}),
c
inx
);
auto
lrsp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
4
}}}),
inx
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
lrsp
,
li
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
lrsp
,
li
);
mm
->
add_return
({
r
});
mm
->
add_return
({
r
});
...
@@ -3674,8 +3658,7 @@ TEST_CASE(resize_upsample_linear_ac_test)
...
@@ -3674,8 +3658,7 @@ TEST_CASE(resize_upsample_linear_ac_test)
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
(
s1
,
d1
));
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
(
s1
,
d1
));
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
auto
cx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
x
);
auto
rsp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
4
}}}),
x
);
auto
rsp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
4
}}}),
cx
);
auto
data
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
rsp
,
l_ind
);
auto
data
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
rsp
,
l_ind
);
auto
slc80
=
mm
->
add_instruction
(
auto
slc80
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
8
}}}),
data
);
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
8
}}}),
data
);
...
@@ -3770,8 +3753,7 @@ TEST_CASE(resize_upsample_linear_test)
...
@@ -3770,8 +3753,7 @@ TEST_CASE(resize_upsample_linear_test)
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
(
s1
,
d1
));
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
(
s1
,
d1
));
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
auto
cx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
x
);
auto
rsp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
4
}}}),
x
);
auto
rsp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
4
}}}),
cx
);
auto
data
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
rsp
,
l_ind
);
auto
data
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
rsp
,
l_ind
);
auto
slc80
=
mm
->
add_instruction
(
auto
slc80
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
8
}}}),
data
);
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
8
}}}),
data
);
...
@@ -3825,8 +3807,7 @@ TEST_CASE(resize_upsample_pc_test)
...
@@ -3825,8 +3807,7 @@ TEST_CASE(resize_upsample_pc_test)
std
::
vector
<
int
>
ind
=
{
0
,
1
,
1
,
2
,
3
,
3
,
0
,
1
,
1
,
2
,
3
,
3
,
4
,
5
,
5
,
6
,
7
,
7
,
4
,
5
,
5
,
6
,
7
,
7
};
std
::
vector
<
int
>
ind
=
{
0
,
1
,
1
,
2
,
3
,
3
,
0
,
1
,
1
,
2
,
3
,
3
,
4
,
5
,
5
,
6
,
7
,
7
,
4
,
5
,
5
,
6
,
7
,
7
};
auto
li
=
mm
->
add_literal
(
migraphx
::
literal
(
si
,
ind
));
auto
li
=
mm
->
add_literal
(
migraphx
::
literal
(
si
,
ind
));
auto
cinx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
inx
);
auto
lrsp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
8
}}}),
inx
);
auto
lrsp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
8
}}}),
cinx
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
lrsp
,
li
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
lrsp
,
li
);
mm
->
add_return
({
r
});
mm
->
add_return
({
r
});
...
@@ -3853,8 +3834,7 @@ TEST_CASE(resize_upsample_pf_test)
...
@@ -3853,8 +3834,7 @@ TEST_CASE(resize_upsample_pf_test)
std
::
vector
<
int
>
ind
=
{
0
,
0
,
0
,
1
,
1
,
1
,
0
,
0
,
0
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
,
3
,
2
,
2
,
2
,
3
,
3
,
3
};
std
::
vector
<
int
>
ind
=
{
0
,
0
,
0
,
1
,
1
,
1
,
0
,
0
,
0
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
,
3
,
2
,
2
,
2
,
3
,
3
,
3
};
auto
li
=
mm
->
add_literal
(
migraphx
::
literal
(
si
,
ind
));
auto
li
=
mm
->
add_literal
(
migraphx
::
literal
(
si
,
ind
));
auto
cinx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
inx
);
auto
lrsp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
4
}}}),
inx
);
auto
lrsp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
4
}}}),
cinx
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
lrsp
,
li
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
lrsp
,
li
);
mm
->
add_return
({
r
});
mm
->
add_return
({
r
});
...
@@ -3933,10 +3913,7 @@ TEST_CASE(scatter_test)
...
@@ -3933,10 +3913,7 @@ TEST_CASE(scatter_test)
auto
l2
=
auto
l2
=
mm
->
add_parameter
(
"update"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}});
mm
->
add_parameter
(
"update"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}});
int
axis
=
-
2
;
int
axis
=
-
2
;
auto
cl0
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l0
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatter"
,
{{
"axis"
,
axis
}}),
l0
,
l1
,
l2
);
auto
cl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l1
);
auto
cl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l2
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatter"
,
{{
"axis"
,
axis
}}),
cl0
,
cl1
,
cl2
);
mm
->
add_return
({
r
});
mm
->
add_return
({
r
});
auto
prog
=
migraphx
::
parse_onnx
(
"scatter_test.onnx"
);
auto
prog
=
migraphx
::
parse_onnx
(
"scatter_test.onnx"
);
...
@@ -3999,9 +3976,7 @@ TEST_CASE(shape_gather_test)
...
@@ -3999,9 +3976,7 @@ TEST_CASE(shape_gather_test)
auto
l1
=
auto
l1
=
mm
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
3
}},
l0
->
get_shape
().
lens
());
mm
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
3
}},
l0
->
get_shape
().
lens
());
int
axis
=
0
;
int
axis
=
0
;
auto
cl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
axis
}}),
l1
,
l2
);
auto
cl2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
axis
}}),
cl1
,
cl2
);
auto
prog
=
optimize_onnx
(
"shape_gather_test.onnx"
);
auto
prog
=
optimize_onnx
(
"shape_gather_test.onnx"
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
...
@@ -4322,10 +4297,8 @@ TEST_CASE(squeeze_unsqueeze_test)
...
@@ -4322,10 +4297,8 @@ TEST_CASE(squeeze_unsqueeze_test)
std
::
vector
<
int64_t
>
unsqueeze_axes
{
0
,
1
,
3
,
5
};
std
::
vector
<
int64_t
>
unsqueeze_axes
{
0
,
1
,
3
,
5
};
auto
l0
=
auto
l0
=
mm
->
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
1
,
1
,
2
,
1
}});
mm
->
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
1
,
1
,
2
,
1
}});
auto
cl0
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l0
);
auto
l1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
squeeze_axes
}}),
l0
);
auto
l1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
squeeze_axes
}}),
cl0
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
unsqueeze_axes
}}),
l1
);
auto
cl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
unsqueeze_axes
}}),
cl1
);
auto
prog
=
optimize_onnx
(
"squeeze_unsqueeze_test.onnx"
);
auto
prog
=
optimize_onnx
(
"squeeze_unsqueeze_test.onnx"
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
...
@@ -4337,8 +4310,7 @@ TEST_CASE(squeeze_axes_input_test)
...
@@ -4337,8 +4310,7 @@ TEST_CASE(squeeze_axes_input_test)
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
mm
->
add_literal
(
migraphx
::
literal
({
migraphx
::
shape
::
int64_type
,
{
2
}},
{
1
,
3
}));
mm
->
add_literal
(
migraphx
::
literal
({
migraphx
::
shape
::
int64_type
,
{
2
}},
{
1
,
3
}));
auto
l0
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
1
,
5
,
1
}});
auto
l0
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
1
,
5
,
1
}});
auto
cl0
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l0
);
auto
l1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
1
,
3
}}}),
l0
);
auto
l1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
1
,
3
}}}),
cl0
);
mm
->
add_return
({
l1
});
mm
->
add_return
({
l1
});
auto
prog
=
migraphx
::
parse_onnx
(
"squeeze_axes_input_test.onnx"
);
auto
prog
=
migraphx
::
parse_onnx
(
"squeeze_axes_input_test.onnx"
);
...
@@ -4352,8 +4324,7 @@ TEST_CASE(squeeze_empty_axes_test)
...
@@ -4352,8 +4324,7 @@ TEST_CASE(squeeze_empty_axes_test)
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
mm
->
add_literal
({});
mm
->
add_literal
({});
auto
l0
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
1
,
5
,
1
}});
auto
l0
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
1
,
5
,
1
}});
auto
cl0
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l0
);
auto
l1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"squeeze"
),
l0
);
auto
l1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"squeeze"
),
cl0
);
mm
->
add_return
({
l1
});
mm
->
add_return
({
l1
});
auto
prog
=
migraphx
::
parse_onnx
(
"squeeze_empty_axes_test.onnx"
);
auto
prog
=
migraphx
::
parse_onnx
(
"squeeze_empty_axes_test.onnx"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment