Unverified Commit 4d46cbdb authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Logical ops (#718)

* add the and operator

* clang format

* add unit tests for the and operator

* clang format

* change the and name to logical_and and add the logical_or, logical_xor

* clang format

* add onnx unit tests for or and xor

* add more unit tests
parent 62a1b87b
logical_xor_bcast_test:w

0
12"Xorlogical_xor_bcast_testZ
0
 



Z
1
 

b
2
 



B
\ No newline at end of file
...@@ -1599,6 +1599,52 @@ TEST_CASE(log_test) ...@@ -1599,6 +1599,52 @@ TEST_CASE(log_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(logical_and_bcast_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::bool_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::bool_type, {4, 5}});
auto l2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l0->get_shape().lens()}}), l1);
auto ret = mm->add_instruction(migraphx::make_op("logical_and"), l0, l2);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("logical_and_bcast_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(logical_or_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::bool_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::bool_type, {2, 3, 4, 5}});
auto ret = mm->add_instruction(migraphx::make_op("logical_or"), l0, l1);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("logical_or_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(logical_xor_bcast_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::bool_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::bool_type, {4, 1}});
auto l2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l0->get_shape().lens()}}), l1);
auto ret = mm->add_instruction(migraphx::make_op("logical_xor"), l0, l2);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("logical_xor_bcast_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(logsoftmax_test) TEST_CASE(logsoftmax_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -54,6 +54,7 @@ def create_backend_test(testname=None, target_device=None): ...@@ -54,6 +54,7 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_acos.*') backend_test.include(r'.*test_acos.*')
backend_test.include(r'.*test_acosh.*') backend_test.include(r'.*test_acosh.*')
backend_test.include(r'.*test_add.*') backend_test.include(r'.*test_add.*')
backend_test.include(r'.*test_and.*')
backend_test.include(r'.*test_argmax.*') backend_test.include(r'.*test_argmax.*')
backend_test.include(r'.*test_argmin.*') backend_test.include(r'.*test_argmin.*')
backend_test.include(r'.*test_asin.*') backend_test.include(r'.*test_asin.*')
...@@ -131,6 +132,7 @@ def create_backend_test(testname=None, target_device=None): ...@@ -131,6 +132,7 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_operator_symbolic_override.*') backend_test.include(r'.*test_operator_symbolic_override.*')
backend_test.include(r'.*test_operator_symbolic_override_nested.*') backend_test.include(r'.*test_operator_symbolic_override_nested.*')
backend_test.include(r'.*test_operator_view.*') backend_test.include(r'.*test_operator_view.*')
backend_test.include(r'.*test_or.*')
backend_test.include(r'.*test_pow.*') backend_test.include(r'.*test_pow.*')
backend_test.include(r'.*test_PoissonNLLLLoss_no_reduce*') backend_test.include(r'.*test_PoissonNLLLLoss_no_reduce*')
backend_test.include(r'.*test_quantizelinear') backend_test.include(r'.*test_quantizelinear')
...@@ -163,6 +165,7 @@ def create_backend_test(testname=None, target_device=None): ...@@ -163,6 +165,7 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_unsqueeze.*') backend_test.include(r'.*test_unsqueeze.*')
backend_test.include(r'.*test_where*') backend_test.include(r'.*test_where*')
backend_test.include(r'.*test_where.*') backend_test.include(r'.*test_where.*')
backend_test.include(r'.*test_xor.*')
backend_test.include(r'.*test_ZeroPad2d*') backend_test.include(r'.*test_ZeroPad2d*')
# # Onnx native model tests # # Onnx native model tests
......
...@@ -1822,6 +1822,54 @@ TEST_CASE(logsoftmax_test_axis_3) ...@@ -1822,6 +1822,54 @@ TEST_CASE(logsoftmax_test_axis_3)
EXPECT(migraphx::verify_range(results_vector, s)); EXPECT(migraphx::verify_range(results_vector, s));
} }
TEST_CASE(logical_and_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::bool_type, {4}};
auto l1 = mm->add_literal(migraphx::literal{s, {1, 0, 1, 0}});
auto l2 = mm->add_literal(migraphx::literal{s, {1, 1, 0, 0}});
mm->add_instruction(migraphx::make_op("logical_and"), l1, l2);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold = {1, 0, 0, 0};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(logical_or_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::bool_type, {4}};
auto l1 = mm->add_literal(migraphx::literal{s, {1, 0, 1, 0}});
auto l2 = mm->add_literal(migraphx::literal{s, {1, 1, 0, 0}});
mm->add_instruction(migraphx::make_op("logical_or"), l1, l2);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold = {1, 1, 1, 0};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(logical_xor_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::bool_type, {4}};
auto l1 = mm->add_literal(migraphx::literal{s, {1, 0, 1, 0}});
auto l2 = mm->add_literal(migraphx::literal{s, {1, 1, 0, 0}});
mm->add_instruction(migraphx::make_op("logical_xor"), l1, l2);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold = {0, 1, 1, 0};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(argmax_test_0) TEST_CASE(argmax_test_0)
{ {
migraphx::program p; migraphx::program p;
......
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_and : verify_program<test_and>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::bool_type, {3}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(migraphx::make_op("logical_and"), x, y);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_or : verify_program<test_or>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::bool_type, {3}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(migraphx::make_op("logical_or"), x, y);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_xor : verify_program<test_xor>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::bool_type, {3}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(migraphx::make_op("logical_xor"), x, y);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment