Commit 0adbc8d4 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add unit test for the round operator

parent d3d6a51e
...@@ -55,6 +55,7 @@ struct onnx_parser ...@@ -55,6 +55,7 @@ struct onnx_parser
add_generic_op("Acos", op::acos{}); add_generic_op("Acos", op::acos{});
add_generic_op("Atan", op::atan{}); add_generic_op("Atan", op::atan{});
add_generic_op("Sqrt", op::sqrt{}); add_generic_op("Sqrt", op::sqrt{});
add_generic_op("Round", op::round{});
add_binary_op("Add", op::add{}); add_binary_op("Add", op::add{});
add_binary_op("Div", op::div{}); add_binary_op("Div", op::div{});
......
...@@ -1912,4 +1912,24 @@ TEST_CASE(sqdiff_test) ...@@ -1912,4 +1912,24 @@ TEST_CASE(sqdiff_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(round_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {9}};
auto l = p.add_literal(
migraphx::literal{s, {1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0}});
p.add_instruction(migraphx::op::round{}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
for (auto v : results_vector) {
std::cout << v << "\t";
}
std::cout << std::endl;
std::vector<float> gold = {1.0, 2.0, 2.0, -1.0, -2.0, -2.0, 0.0, 2.0, -2.0};
EXPECT(migraphx::verify_range(results_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -3619,4 +3619,16 @@ struct test_reduce_mean_half : verify_program<test_reduce_mean_half> ...@@ -3619,4 +3619,16 @@ struct test_reduce_mean_half : verify_program<test_reduce_mean_half>
}; };
}; };
struct test_round : verify_program<test_round>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}};
auto param = p.add_parameter("x", s);
p.add_instruction(migraphx::op::round{}, param);
return p;
}
};
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -1012,4 +1012,15 @@ TEST_CASE(expand_test) ...@@ -1012,4 +1012,15 @@ TEST_CASE(expand_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(round_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}});
p.add_instruction(migraphx::op::round{}, input);
auto prog = migraphx::parse_onnx("round_test.onnx");
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
 round-example:E
xy"Round
test_roundZ
x
 

b
y
 

B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment