"src/vscode:/vscode.git/clone" did not exist on "b75c83d8c8115ce6ee89cb50d043e8999ee08abb"
Unverified Commit 59b80d4e authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Bool type and equal operator (#603)



* add bool type

* code backup

* code backup

* clang format

* fix build warnings

* clang format

* add the equal operator

* add the equal operator

* clang format

* remove unnecessary code

* refine unit tests

* clang format

* fix review comments and a bug

* clang format

* additional changes

* clang format

* fix cppcheck error

* add bool type in c api

* fix cppcheck error

* fix review comments

* fix cppcheck error

* fix a build error related to gcc

* fix cppcheck error

* fix cppcheck error

* added the equal operator to register list

* add parsing boolean type

* clang format

* fix bool type issue for python output

* clang format

* add support for automatic multibroadcast of the equal operator

* additional unit tests for more code coverage

* clang format

* missing an onnx file
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
parent 3eb4f775
......@@ -774,6 +774,40 @@ TEST_CASE(embedding_bag_offset_test)
EXPECT(test::throws([&] { migraphx::parse_onnx("embedding_bag_offset_test.onnx"); }));
}
TEST_CASE(equal_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
auto input1 = p.add_literal(migraphx::literal(s, data));
auto input2 = p.add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto eq = p.add_instruction(migraphx::op::equal{}, input1, input2);
auto ret = p.add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, eq);
p.add_return({ret});
auto prog = migraphx::parse_onnx("equal_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(equal_bool_test)
{
migraphx::program p;
migraphx::shape sf{migraphx::shape::float_type, {2, 3}};
migraphx::shape sb{migraphx::shape::bool_type, {2, 3}};
auto input1 = p.add_parameter("x1", sf);
auto input2 = p.add_parameter("x2", sb);
auto cin1 = p.add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, input1);
auto ret = p.add_instruction(migraphx::op::equal{}, cin1, input2);
p.add_return({ret});
auto prog = migraphx::parse_onnx("equal_bool_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(erf_test)
{
migraphx::program p;
......
......@@ -217,8 +217,6 @@ def create_backend_test(testname=None, target_device=None):
backend_test.exclude(r'test_depthtospace_crd_mode_example_cpu')
backend_test.exclude(r'test_depthtospace_dcr_mode_cpu')
backend_test.exclude(r'test_depthtospace_example_cpu')
backend_test.exclude(r'test_equal_bcast_cpu')
backend_test.exclude(r'test_equal_cpu')
backend_test.exclude(r'test_expand_dim_changed_cpu')
backend_test.exclude(r'test_expand_dim_unchanged_cpu')
backend_test.exclude(r'test_gather_0_cpu')
......
......@@ -6,6 +6,7 @@
// Add new types here
// clang-format off
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
m(bool_type, bool) \
m(half_type, half) \
m(float_type, float) \
m(double_type, double) \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment