Commit 9f375dfc authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add tests for the operator erf

parent 5caf0b33
...@@ -40,6 +40,7 @@ struct onnx_parser ...@@ -40,6 +40,7 @@ struct onnx_parser
add_generic_op("Sigmoid", op::sigmoid{}); add_generic_op("Sigmoid", op::sigmoid{});
add_generic_op("Abs", op::abs{}); add_generic_op("Abs", op::abs{});
add_generic_op("Exp", op::exp{}); add_generic_op("Exp", op::exp{});
add_generic_op("Erf", op::erf{});
add_generic_op("Log", op::log{}); add_generic_op("Log", op::log{});
// disable dropout for inference // disable dropout for inference
add_generic_op("Dropout", op::identity{}); add_generic_op("Dropout", op::identity{});
......
...@@ -15,6 +15,7 @@ add_library(migraphx_device ...@@ -15,6 +15,7 @@ add_library(migraphx_device
device/max.cpp device/max.cpp
device/min.cpp device/min.cpp
device/exp.cpp device/exp.cpp
device/erf.cpp
device/log.cpp device/log.cpp
device/sin.cpp device/sin.cpp
device/cos.cpp device/cos.cpp
......
...@@ -85,7 +85,7 @@ struct miopen_apply ...@@ -85,7 +85,7 @@ struct miopen_apply
add_generic_op<hip_add>("add"); add_generic_op<hip_add>("add");
add_generic_op<hip_sub>("sub"); add_generic_op<hip_sub>("sub");
add_generic_op<hip_exp>("exp"); add_generic_op<hip_exp>("exp");
add_generic_op<hip_exp>("erf"); add_generic_op<hip_erf>("erf");
add_generic_op<hip_log>("log"); add_generic_op<hip_log>("log");
add_generic_op<hip_sin>("sin"); add_generic_op<hip_sin>("sin");
add_generic_op<hip_cos>("cos"); add_generic_op<hip_cos>("cos");
......
...@@ -527,6 +527,20 @@ TEST_CASE(exp_test) ...@@ -527,6 +527,20 @@ TEST_CASE(exp_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(erf_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {4}};
auto l = p.add_literal(migraphx::literal{s, {0.73785057, 1.58165966, -0.43597795, -0.01677432}});
p.add_instruction(migraphx::op::erf{}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.70327317, 0.97470088, -0.46247893, -0.01892602};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(log_test) TEST_CASE(log_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -243,6 +243,18 @@ struct test_exp : verify_program<test_exp> ...@@ -243,6 +243,18 @@ struct test_exp : verify_program<test_exp>
} }
}; };
struct test_erf : verify_program<test_erf>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}};
auto param = p.add_parameter("x", s);
p.add_instruction(migraphx::op::erf{}, param);
return p;
}
};
struct test_log : verify_program<test_log> struct test_log : verify_program<test_log>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
...@@ -192,6 +192,16 @@ TEST_CASE(exp_test) ...@@ -192,6 +192,16 @@ TEST_CASE(exp_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(erf_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
p.add_instruction(migraphx::op::erf{}, input);
auto prog = migraphx::parse_onnx("erf_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(log_test) TEST_CASE(log_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment