Commit e7bcbd56 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add tests for the pow operator

parent 28b33e4a
...@@ -58,11 +58,11 @@ struct onnx_parser ...@@ -58,11 +58,11 @@ struct onnx_parser
add_binary_op("Div", op::div{}); add_binary_op("Div", op::div{});
add_binary_op("Mul", op::mul{}); add_binary_op("Mul", op::mul{});
add_binary_op("Sub", op::sub{}); add_binary_op("Sub", op::sub{});
add_binary_op("Pow", op::pow{});
add_variadic_op("Sum", op::add{}); add_variadic_op("Sum", op::add{});
add_variadic_op("Max", op::max{}); add_variadic_op("Max", op::max{});
add_variadic_op("Min", op::min{}); add_variadic_op("Min", op::min{});
add_variadic_op("Pow", op::pow{});
add_mem_op("ArgMax", &onnx_parser::parse_argmax); add_mem_op("ArgMax", &onnx_parser::parse_argmax);
add_mem_op("ArgMin", &onnx_parser::parse_argmin); add_mem_op("ArgMin", &onnx_parser::parse_argmin);
......
...@@ -6,7 +6,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -6,7 +6,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void pow(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) void pow(hipStream_t stream, const argument& result, const argument& arg2, const argument& arg1)
{ {
nary(stream, result, arg1, arg2)( nary(stream, result, arg1, arg2)(
[](auto x, auto y) { return ::pow(to_hip_type(x), to_hip_type(y)); }); [](auto x, auto y) { return ::pow(to_hip_type(x), to_hip_type(y)); });
......
...@@ -541,6 +541,21 @@ TEST_CASE(log_test) ...@@ -541,6 +541,21 @@ TEST_CASE(log_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(pow_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto b = p.add_literal(migraphx::literal{s, {1, 2, 3}});
auto e = p.add_literal(migraphx::literal{s, {1, 2, 3}});
p.add_instruction(migraphx::op::pow{}, b, e);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0f, 4.0f, 27.0f};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(sin_test) TEST_CASE(sin_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -255,6 +255,20 @@ struct test_log : verify_program<test_log> ...@@ -255,6 +255,20 @@ struct test_log : verify_program<test_log>
} }
}; };
struct test_pow : verify_program<test_pow>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {6}};
std::vector<float> vec_e(s.elements(), 2.0f);
auto b = p.add_parameter("x", s);
auto e = p.add_literal(migraphx::literal(s, vec_e));
p.add_instruction(migraphx::op::pow{}, b, e);
return p;
}
};
struct test_sin : verify_program<test_sin> struct test_sin : verify_program<test_sin>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
...@@ -858,4 +858,32 @@ TEST_CASE(clip_test) ...@@ -858,4 +858,32 @@ TEST_CASE(clip_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(implicit_pow_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::pow{}, l2, l3);
auto prog = migraphx::parse_onnx("pow_bcast_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(pow_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
p.add_instruction(migraphx::op::pow{}, l0, l1);
auto prog = migraphx::parse_onnx("pow_bcast_test1.onnx");
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment