Unverified Commit 49e65e08 authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge pull request #317 from ROCmSoftwarePlatform/pow_operator

Pow_operator
parents ff0090da ff25a0bb
#ifndef MIGRAPHX_GUARD_OPERATORS_POW_HPP
#define MIGRAPHX_GUARD_OPERATORS_POW_HPP
#include <migraphx/op/binary.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct pow : binary<pow>
{
auto apply() const
{
return [](auto x, auto y) { return std::pow(x, y); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -45,6 +45,7 @@
#include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pow.hpp>
#include <migraphx/op/reduce_sum.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/relu.hpp>
......
......@@ -60,6 +60,7 @@ struct onnx_parser
add_binary_op("Div", op::div{});
add_binary_op("Mul", op::mul{});
add_binary_op("Sub", op::sub{});
add_binary_op("Pow", op::pow{});
add_variadic_op("Sum", op::add{});
add_variadic_op("Max", op::max{});
......
......@@ -42,6 +42,7 @@ add_library(migraphx_device
device/reduce_sum.cpp
device/sqrt.cpp
device/reduce_mean.cpp
device/pow.cpp
device/sqdiff.cpp
)
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
......
#include <migraphx/gpu/device/pow.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void pow(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
nary(stream, result, arg1, arg2)(
[](auto b, auto e) { return ::pow(to_hip_type(b), to_hip_type(e)); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_POW_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_POW_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void pow(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_POW_HPP
#define MIGRAPHX_GUARD_RTGLIB_POW_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/pow.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_pow : binary_device<hip_pow, device::pow>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -54,6 +54,7 @@
#include <migraphx/gpu/reduce_sum.hpp>
#include <migraphx/gpu/sqrt.hpp>
#include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/pow.hpp>
#include <migraphx/gpu/sqdiff.hpp>
#include <utility>
#include <functional>
......@@ -106,6 +107,7 @@ struct miopen_apply
add_generic_op<hip_div>("div");
add_generic_op<hip_max>("max");
add_generic_op<hip_min>("min");
add_generic_op<hip_pow>("pow");
add_generic_op<hip_sqdiff>("sqdiff");
add_extend_op<miopen_gemm, op::dot>("dot");
......
......@@ -571,6 +571,21 @@ TEST_CASE(log_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(pow_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto b = p.add_literal(migraphx::literal{s, {1, 2, 3}});
auto e = p.add_literal(migraphx::literal{s, {1, 2, 3}});
p.add_instruction(migraphx::op::pow{}, b, e);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0f, 4.0f, 27.0f};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(sin_test)
{
migraphx::program p;
......
......@@ -280,6 +280,20 @@ struct test_log : verify_program<test_log>
}
};
struct test_pow : verify_program<test_pow>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {6}};
std::vector<float> vec_e(s.elements(), 2.0f);
auto b = p.add_parameter("x", s);
auto e = p.add_literal(migraphx::literal(s, vec_e));
p.add_instruction(migraphx::op::pow{}, b, e);
return p;
}
};
struct test_sin : verify_program<test_sin>
{
migraphx::program create_program() const
......
......@@ -899,4 +899,30 @@ TEST_CASE(clip_test)
EXPECT(p == prog);
}
TEST_CASE(implicit_pow_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::pow{}, l2, l3);
auto prog = migraphx::parse_onnx("pow_bcast_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(pow_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
p.add_instruction(migraphx::op::pow{}, l0, l1);
auto prog = migraphx::parse_onnx("pow_bcast_test1.onnx");
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
pow2:q

0
1out"Powpow_testZ
0




Z
1



b
out




B
pow2:u

0
1out"Powpow_testZ
0




Z
1




b
out




B
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment