"...gpu/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "8202e4112b35a274b356aa171af2e35396ed9aa2"
Unverified Commit 48ffbfa5 authored by turneram's avatar turneram Committed by GitHub
Browse files

Added greater and less operators (#660)



* Added greater and less operators

* Fixed ops_test.cpp

* Set commutative to false for less, greater

* Refactored parse_equal/less/greater into parse_compare_op

* Removed unnecessary function attributes() from greater.hpp/less.hpp

* Added op_name arguments

* Removed local settings

* Formatting

* Missing comma

* Formatting

* Formatting

* Formatting

* Formatting

* Formatting

* Missing space
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
parent 1d98fbb4
...@@ -1028,6 +1028,38 @@ TEST_CASE(globalmaxpool_test) ...@@ -1028,6 +1028,38 @@ TEST_CASE(globalmaxpool_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(greater_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
auto input1 = p.add_literal(migraphx::literal(s, data));
auto input2 = p.add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto gr = p.add_instruction(migraphx::op::greater{}, input1, input2);
auto ret = p.add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, gr);
p.add_return({ret});
auto prog = migraphx::parse_onnx("greater_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(greater_bool_test)
{
migraphx::program p;
migraphx::shape sf{migraphx::shape::float_type, {2, 3}};
migraphx::shape sb{migraphx::shape::bool_type, {2, 3}};
auto input1 = p.add_parameter("x1", sf);
auto input2 = p.add_parameter("x2", sb);
auto cin1 = p.add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, input1);
auto ret = p.add_instruction(migraphx::op::greater{}, cin1, input2);
p.add_return({ret});
auto prog = migraphx::parse_onnx("greater_bool_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(group_conv_test) TEST_CASE(group_conv_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -1191,6 +1223,38 @@ TEST_CASE(leaky_relu_test) ...@@ -1191,6 +1223,38 @@ TEST_CASE(leaky_relu_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(less_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
auto input1 = p.add_literal(migraphx::literal(s, data));
auto input2 = p.add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto le = p.add_instruction(migraphx::op::less{}, input1, input2);
auto ret = p.add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, le);
p.add_return({ret});
auto prog = migraphx::parse_onnx("less_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(less_bool_test)
{
migraphx::program p;
migraphx::shape sf{migraphx::shape::float_type, {2, 3}};
migraphx::shape sb{migraphx::shape::bool_type, {2, 3}};
auto input1 = p.add_parameter("x1", sf);
auto input2 = p.add_parameter("x2", sb);
auto cin1 = p.add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, input1);
auto ret = p.add_instruction(migraphx::op::less{}, cin1, input2);
p.add_return({ret});
auto prog = migraphx::parse_onnx("less_bool_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(log_test) TEST_CASE(log_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -225,8 +225,6 @@ def create_backend_test(testname=None, target_device=None): ...@@ -225,8 +225,6 @@ def create_backend_test(testname=None, target_device=None):
backend_test.exclude(r'test_gathernd_example_float32_cpu') backend_test.exclude(r'test_gathernd_example_float32_cpu')
backend_test.exclude(r'test_gathernd_example_int32_batch_dim1_cpu') backend_test.exclude(r'test_gathernd_example_int32_batch_dim1_cpu')
backend_test.exclude(r'test_gathernd_example_int32_cpu') backend_test.exclude(r'test_gathernd_example_int32_cpu')
backend_test.exclude(r'test_greater_bcast_cpu')
backend_test.exclude(r'test_greater_cpu')
backend_test.exclude(r'test_greater_equal_bcast_cpu') backend_test.exclude(r'test_greater_equal_bcast_cpu')
backend_test.exclude(r'test_greater_equal_bcast_expanded_cpu') backend_test.exclude(r'test_greater_equal_bcast_expanded_cpu')
backend_test.exclude(r'test_greater_equal_cpu') backend_test.exclude(r'test_greater_equal_cpu')
...@@ -234,8 +232,6 @@ def create_backend_test(testname=None, target_device=None): ...@@ -234,8 +232,6 @@ def create_backend_test(testname=None, target_device=None):
backend_test.exclude(r'test_hardsigmoid_cpu') backend_test.exclude(r'test_hardsigmoid_cpu')
backend_test.exclude(r'test_hardsigmoid_default_cpu') backend_test.exclude(r'test_hardsigmoid_default_cpu')
backend_test.exclude(r'test_hardsigmoid_example_cpu') backend_test.exclude(r'test_hardsigmoid_example_cpu')
backend_test.exclude(r'test_less_bcast_cpu')
backend_test.exclude(r'test_less_cpu')
backend_test.exclude(r'test_less_equal_bcast_cpu') backend_test.exclude(r'test_less_equal_bcast_cpu')
backend_test.exclude(r'test_less_equal_bcast_expanded_cpu') backend_test.exclude(r'test_less_equal_bcast_expanded_cpu')
backend_test.exclude(r'test_less_equal_cpu') backend_test.exclude(r'test_less_equal_cpu')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment