Unverified commit 60aa1c85, authored by turneram, committed by GitHub

GreaterOrEqual ONNX parser (#1044)

Add ONNX parser for the GreaterOrEqual operator.
parent ebb15dd3
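The parser does not add a new MIGraphX operator; it lowers GreaterOrEqual to the existing less and not primitives via the identity a >= b == not(a < b), inserting a convert to bool when the comparison result is not already boolean. A quick NumPy check of that identity, reusing the input values from the ref test in this change (this snippet is illustrative only and is not part of the commit):

import numpy as np

# Inputs taken from the greaterorequal_test ref test below.
a = np.array([0.25, 0.75, 0.9375], dtype=np.float32)
b = np.array([0.25, 0.74, 0.9411], dtype=np.float32)

# GreaterOrEqual(a, b) is equivalent to not(Less(a, b)).
assert np.array_equal(a >= b, np.logical_not(a < b))
print((a >= b).astype(np.float32))  # [1. 1. 0.] -- matches the test's gold values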
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_greaterorequal : op_parser<parse_greaterorequal>
{
    std::vector<op_desc> operators() const { return {{"GreaterOrEqual"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // GreaterOrEqual(a, b) is computed as not(less(a, b)).
        auto in_res = info.add_broadcastable_binary_op("less", args[0], args[1]);

        if(in_res->get_shape().type() != shape::bool_type)
        {
            in_res = info.add_instruction(
                make_op("convert", {{"target_type", shape::bool_type}}), in_res);
        }

        return info.add_instruction(make_op("not"), in_res);
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -1618,6 +1618,22 @@ def greater_bool_test():
    return ([node1, node2], [x1, x2], [y])


@onnx_test
def greaterorequal_test():
    x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3])
    x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3])

    node = onnx.helper.make_node(
        'GreaterOrEqual',
        inputs=['x1', 'x2'],
        outputs=['y'],
    )

    return ([node], [x1, x2], [y])


@onnx_test
def group_conv_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 4, 16, 16])
[binary ONNX protobuf for greaterorequal_test added: a single GreaterOrEqual node with float[3] inputs x1 and x2 and output y, as generated by greaterorequal_test() above]
@@ -1549,6 +1549,24 @@ TEST_CASE(greater_bool_test)
    EXPECT(p == prog);
}

TEST_CASE(greaterorequal_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto input1 = mm->add_parameter("x1", migraphx::shape{migraphx::shape::float_type, {3}});
    auto input2 = mm->add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {3}});
    auto temp   = mm->add_instruction(migraphx::make_op("less"), input1, input2);
    auto bt     = mm->add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), temp);
    auto ge = mm->add_instruction(migraphx::make_op("not"), bt);
    mm->add_return({ge});

    auto prog = migraphx::parse_onnx("greaterorequal_test.onnx");
    EXPECT(p == prog);
}

TEST_CASE(group_conv_test)
{
    migraphx::program p;
@@ -126,6 +126,27 @@ TEST_CASE(gather_elements)
    EXPECT(migraphx::verify_range(result_vector, gold));
}

TEST_CASE(greaterorequal_test)
{
    migraphx::program p = migraphx::parse_onnx("greaterorequal_test.onnx");
    p.compile(migraphx::ref::target{});

    migraphx::shape s{migraphx::shape::float_type, {3}};
    std::vector<float> data1 = {0.25, 0.75, 0.9375};
    std::vector<float> data2 = {0.25, 0.74, 0.9411};

    migraphx::parameter_map pp;
    pp["x1"] = migraphx::argument(s, data1.data());
    pp["x2"] = migraphx::argument(s, data2.data());

    auto result = p.eval(pp).back();
    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {1.0, 1.0, 0.0};
    EXPECT(migraphx::verify_range(result_vector, gold));
}

TEST_CASE(hardsigmoid_verify_test)
{
    migraphx::program p = migraphx::parse_onnx("hardsigmoid_verify_test.onnx");
@@ -266,10 +266,6 @@ def create_backend_test(testname=None, target_device=None):
    backend_test.exclude(r'test_gathernd_example_float32_cpu')
    backend_test.exclude(r'test_gathernd_example_int32_batch_dim1_cpu')
    backend_test.exclude(r'test_gathernd_example_int32_cpu')
    backend_test.exclude(r'test_greater_equal_bcast_cpu')
    backend_test.exclude(r'test_greater_equal_bcast_expanded_cpu')
    backend_test.exclude(r'test_greater_equal_cpu')
    backend_test.exclude(r'test_greater_equal_expanded_cpu')
    backend_test.exclude(r'test_identity_sequence_cpu')
    backend_test.exclude(r'test_maxpool_2d_uint8_cpu')
    backend_test.exclude(r'test_mean_example_cpu')