Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
60aa1c85
Unverified
Commit
60aa1c85
authored
Jan 21, 2022
by
turneram
Committed by
GitHub
Jan 21, 2022
Browse files
GreaterOrEqual ONNX parser (#1044)
Add onnx parser for operator GreaterOrEqual
parent
ebb15dd3
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
102 additions
and
4 deletions
+102
-4
src/onnx/parse_greaterorequal.cpp
src/onnx/parse_greaterorequal.cpp
+31
-0
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+16
-0
test/onnx/greaterorequal_test.onnx
test/onnx/greaterorequal_test.onnx
+16
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+18
-0
test/onnx/verify_onnx.cpp
test/onnx/verify_onnx.cpp
+21
-0
test/py/onnx_backend_test.py
test/py/onnx_backend_test.py
+0
-4
No files found.
src/onnx/parse_greaterorequal.cpp
0 → 100644
View file @
60aa1c85
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
struct
parse_greaterorequal
:
op_parser
<
parse_greaterorequal
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"GreaterOrEqual"
}};
}
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
const
onnx_parser
&
/*parser*/
,
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
auto
in_res
=
info
.
add_broadcastable_binary_op
(
"less"
,
args
[
0
],
args
[
1
]);
if
(
in_res
->
get_shape
().
type
()
!=
shape
::
bool_type
)
{
in_res
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
bool_type
}}),
in_res
);
}
return
info
.
add_instruction
(
make_op
(
"not"
),
in_res
);
}
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
test/onnx/gen_onnx.py
View file @
60aa1c85
...
...
@@ -1618,6 +1618,22 @@ def greater_bool_test():
return
([
node1
,
node2
],
[
x1
,
x2
],
[
y
])
@
onnx_test
def
greaterorequal_test
():
x1
=
helper
.
make_tensor_value_info
(
'x1'
,
TensorProto
.
FLOAT
,
[
3
])
x2
=
helper
.
make_tensor_value_info
(
'x2'
,
TensorProto
.
FLOAT
,
[
3
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
3
])
node
=
onnx
.
helper
.
make_node
(
'GreaterOrEqual'
,
inputs
=
[
'x1'
,
'x2'
],
outputs
=
[
'y'
],
)
return
([
node
],
[
x1
,
x2
],
[
y
])
@
onnx_test
def
group_conv_test
():
x
=
helper
.
make_tensor_value_info
(
'0'
,
TensorProto
.
FLOAT
,
[
1
,
4
,
16
,
16
])
...
...
test/onnx/greaterorequal_test.onnx
0 → 100644
View file @
60aa1c85
greaterorequal_test:g
x1
x2y"GreaterOrEqualgreaterorequal_testZ
x1
Z
x2
b
y
B
\ No newline at end of file
test/onnx/onnx_test.cpp
View file @
60aa1c85
...
...
@@ -1549,6 +1549,24 @@ TEST_CASE(greater_bool_test)
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
greaterorequal_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
input1
=
mm
->
add_parameter
(
"x1"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
}});
auto
input2
=
mm
->
add_parameter
(
"x2"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
}});
auto
temp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"less"
),
input1
,
input2
);
auto
bt
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
shape
::
bool_type
}}),
temp
);
auto
ge
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"not"
),
bt
);
mm
->
add_return
({
ge
});
auto
prog
=
migraphx
::
parse_onnx
(
"greaterorequal_test.onnx"
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
group_conv_test
)
{
migraphx
::
program
p
;
...
...
test/onnx/verify_onnx.cpp
View file @
60aa1c85
...
...
@@ -126,6 +126,27 @@ TEST_CASE(gather_elements)
EXPECT
(
migraphx
::
verify_range
(
result_vector
,
gold
));
}
TEST_CASE
(
greaterorequal_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"greaterorequal_test.onnx"
);
p
.
compile
(
migraphx
::
ref
::
target
{});
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
}};
std
::
vector
<
float
>
data1
=
{
0.25
,
0.75
,
0.9375
};
std
::
vector
<
float
>
data2
=
{
0.25
,
0.74
,
0.9411
};
migraphx
::
parameter_map
pp
;
pp
[
"x1"
]
=
migraphx
::
argument
(
s
,
data1
.
data
());
pp
[
"x2"
]
=
migraphx
::
argument
(
s
,
data2
.
data
());
auto
result
=
p
.
eval
(
pp
).
back
();
std
::
vector
<
float
>
result_vector
;
result
.
visit
([
&
](
auto
output
)
{
result_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
1.0
,
1.0
,
0.0
};
EXPECT
(
migraphx
::
verify_range
(
result_vector
,
gold
));
}
TEST_CASE
(
hardsigmoid_verify_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"hardsigmoid_verify_test.onnx"
);
...
...
test/py/onnx_backend_test.py
View file @
60aa1c85
...
...
@@ -266,10 +266,6 @@ def create_backend_test(testname=None, target_device=None):
backend_test
.
exclude
(
r
'test_gathernd_example_float32_cpu'
)
backend_test
.
exclude
(
r
'test_gathernd_example_int32_batch_dim1_cpu'
)
backend_test
.
exclude
(
r
'test_gathernd_example_int32_cpu'
)
backend_test
.
exclude
(
r
'test_greater_equal_bcast_cpu'
)
backend_test
.
exclude
(
r
'test_greater_equal_bcast_expanded_cpu'
)
backend_test
.
exclude
(
r
'test_greater_equal_cpu'
)
backend_test
.
exclude
(
r
'test_greater_equal_expanded_cpu'
)
backend_test
.
exclude
(
r
'test_identity_sequence_cpu'
)
backend_test
.
exclude
(
r
'test_maxpool_2d_uint8_cpu'
)
backend_test
.
exclude
(
r
'test_mean_example_cpu'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment