Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
60aa1c85
Unverified
Commit
60aa1c85
authored
Jan 21, 2022
by
turneram
Committed by
GitHub
Jan 21, 2022
Browse files
GreaterOrEqual ONNX parser (#1044)
Add onnx parser for operator GreaterOrEqual
parent
ebb15dd3
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
102 additions
and
4 deletions
+102
-4
src/onnx/parse_greaterorequal.cpp
src/onnx/parse_greaterorequal.cpp
+31
-0
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+16
-0
test/onnx/greaterorequal_test.onnx
test/onnx/greaterorequal_test.onnx
+16
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+18
-0
test/onnx/verify_onnx.cpp
test/onnx/verify_onnx.cpp
+21
-0
test/py/onnx_backend_test.py
test/py/onnx_backend_test.py
+0
-4
No files found.
src/onnx/parse_greaterorequal.cpp
0 → 100644
View file @
60aa1c85
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
struct
parse_greaterorequal
:
op_parser
<
parse_greaterorequal
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"GreaterOrEqual"
}};
}
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
const
onnx_parser
&
/*parser*/
,
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
auto
in_res
=
info
.
add_broadcastable_binary_op
(
"less"
,
args
[
0
],
args
[
1
]);
if
(
in_res
->
get_shape
().
type
()
!=
shape
::
bool_type
)
{
in_res
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
bool_type
}}),
in_res
);
}
return
info
.
add_instruction
(
make_op
(
"not"
),
in_res
);
}
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
test/onnx/gen_onnx.py
View file @
60aa1c85
...
@@ -1618,6 +1618,22 @@ def greater_bool_test():
...
@@ -1618,6 +1618,22 @@ def greater_bool_test():
return
([
node1
,
node2
],
[
x1
,
x2
],
[
y
])
return
([
node1
,
node2
],
[
x1
,
x2
],
[
y
])
@
onnx_test
def
greaterorequal_test
():
x1
=
helper
.
make_tensor_value_info
(
'x1'
,
TensorProto
.
FLOAT
,
[
3
])
x2
=
helper
.
make_tensor_value_info
(
'x2'
,
TensorProto
.
FLOAT
,
[
3
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
3
])
node
=
onnx
.
helper
.
make_node
(
'GreaterOrEqual'
,
inputs
=
[
'x1'
,
'x2'
],
outputs
=
[
'y'
],
)
return
([
node
],
[
x1
,
x2
],
[
y
])
@
onnx_test
@
onnx_test
def
group_conv_test
():
def
group_conv_test
():
x
=
helper
.
make_tensor_value_info
(
'0'
,
TensorProto
.
FLOAT
,
[
1
,
4
,
16
,
16
])
x
=
helper
.
make_tensor_value_info
(
'0'
,
TensorProto
.
FLOAT
,
[
1
,
4
,
16
,
16
])
...
...
test/onnx/greaterorequal_test.onnx
0 → 100644
View file @
60aa1c85
greaterorequal_test:g
x1
x2y"GreaterOrEqualgreaterorequal_testZ
x1
Z
x2
b
y
B
\ No newline at end of file
test/onnx/onnx_test.cpp
View file @
60aa1c85
...
@@ -1549,6 +1549,24 @@ TEST_CASE(greater_bool_test)
...
@@ -1549,6 +1549,24 @@ TEST_CASE(greater_bool_test)
EXPECT(p == prog);
EXPECT(p == prog);
}
}
TEST_CASE(greaterorequal_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input1 = mm->add_parameter("x1", migraphx::shape{migraphx::shape::float_type, {3}});
auto input2 = mm->add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {3}});
auto temp = mm->add_instruction(migraphx::make_op("less"), input1, input2);
auto bt = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), temp);
auto ge = mm->add_instruction(migraphx::make_op("not"), bt);
mm->add_return({ge});
auto prog = migraphx::parse_onnx("greaterorequal_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(group_conv_test)
TEST_CASE(group_conv_test)
{
{
migraphx::program p;
migraphx::program p;
...
...
test/onnx/verify_onnx.cpp
View file @
60aa1c85
...
@@ -126,6 +126,27 @@ TEST_CASE(gather_elements)
...
@@ -126,6 +126,27 @@ TEST_CASE(gather_elements)
EXPECT
(
migraphx
::
verify_range
(
result_vector
,
gold
));
EXPECT
(
migraphx
::
verify_range
(
result_vector
,
gold
));
}
}
TEST_CASE
(
greaterorequal_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"greaterorequal_test.onnx"
);
p
.
compile
(
migraphx
::
ref
::
target
{});
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
}};
std
::
vector
<
float
>
data1
=
{
0.25
,
0.75
,
0.9375
};
std
::
vector
<
float
>
data2
=
{
0.25
,
0.74
,
0.9411
};
migraphx
::
parameter_map
pp
;
pp
[
"x1"
]
=
migraphx
::
argument
(
s
,
data1
.
data
());
pp
[
"x2"
]
=
migraphx
::
argument
(
s
,
data2
.
data
());
auto
result
=
p
.
eval
(
pp
).
back
();
std
::
vector
<
float
>
result_vector
;
result
.
visit
([
&
](
auto
output
)
{
result_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
1.0
,
1.0
,
0.0
};
EXPECT
(
migraphx
::
verify_range
(
result_vector
,
gold
));
}
TEST_CASE
(
hardsigmoid_verify_test
)
TEST_CASE
(
hardsigmoid_verify_test
)
{
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"hardsigmoid_verify_test.onnx"
);
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"hardsigmoid_verify_test.onnx"
);
...
...
test/py/onnx_backend_test.py
View file @
60aa1c85
...
@@ -266,10 +266,6 @@ def create_backend_test(testname=None, target_device=None):
...
@@ -266,10 +266,6 @@ def create_backend_test(testname=None, target_device=None):
backend_test
.
exclude
(
r
'test_gathernd_example_float32_cpu'
)
backend_test
.
exclude
(
r
'test_gathernd_example_float32_cpu'
)
backend_test
.
exclude
(
r
'test_gathernd_example_int32_batch_dim1_cpu'
)
backend_test
.
exclude
(
r
'test_gathernd_example_int32_batch_dim1_cpu'
)
backend_test
.
exclude
(
r
'test_gathernd_example_int32_cpu'
)
backend_test
.
exclude
(
r
'test_gathernd_example_int32_cpu'
)
backend_test
.
exclude
(
r
'test_greater_equal_bcast_cpu'
)
backend_test
.
exclude
(
r
'test_greater_equal_bcast_expanded_cpu'
)
backend_test
.
exclude
(
r
'test_greater_equal_cpu'
)
backend_test
.
exclude
(
r
'test_greater_equal_expanded_cpu'
)
backend_test
.
exclude
(
r
'test_identity_sequence_cpu'
)
backend_test
.
exclude
(
r
'test_identity_sequence_cpu'
)
backend_test
.
exclude
(
r
'test_maxpool_2d_uint8_cpu'
)
backend_test
.
exclude
(
r
'test_maxpool_2d_uint8_cpu'
)
backend_test
.
exclude
(
r
'test_mean_example_cpu'
)
backend_test
.
exclude
(
r
'test_mean_example_cpu'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment