Commit 5cc55733 authored by Khalique

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into gather_op

parents 8d2224f1 292b6aab
#ifndef MIGRAPHX_GUARD_OPERATORS_RSQRT_HPP
#define MIGRAPHX_GUARD_OPERATORS_RSQRT_HPP

#include <migraphx/op/unary.hpp>
#include <cmath>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

// Elementwise reciprocal square root: y = 1 / sqrt(x)
struct rsqrt : unary<rsqrt>
{
    auto apply() const
    {
        return [](auto x) { return 1 / std::sqrt(x); };
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
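For reference, here is a standalone sketch (plain C++, not MIGraphX code; unary_sketch and rsqrt_sketch are hypothetical stand-ins) of how the unary<Derived> CRTP base is presumed to use apply(): the lambda it returns is applied to every element of the input.

#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the unary CRTP base: run the derived op's lambda elementwise.
template <class Derived>
struct unary_sketch
{
    std::vector<float> run(const std::vector<float>& in) const
    {
        auto f = static_cast<const Derived&>(*this).apply();
        std::vector<float> out;
        out.reserve(in.size());
        for(float x : in)
            out.push_back(f(x));
        return out;
    }
};

struct rsqrt_sketch : unary_sketch<rsqrt_sketch>
{
    auto apply() const
    {
        return [](auto x) { return 1 / std::sqrt(x); };
    }
};

int main()
{
    rsqrt_sketch op;
    for(float v : op.run({4.0f, 16.0f, 64.0f}))
        std::printf("%g\n", v); // prints 0.5 0.25 0.125
}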
@@ -53,6 +53,7 @@
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_cell_output.hpp>
#include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sinh.hpp>
...
@@ -40,6 +40,7 @@ add_library(migraphx_device
device/div.cpp
device/clip.cpp
device/reduce_sum.cpp
device/rsqrt.cpp
device/sqrt.cpp
device/reduce_mean.cpp
device/pow.cpp
...
#include <migraphx/gpu/device/rsqrt.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

// Apply the device rsqrt intrinsic to every element via the nary elementwise helper.
void rsqrt(hipStream_t stream, const argument& result, const argument& arg)
{
    nary(stream, result, arg)([](auto x) __device__ { return ::rsqrt(to_hip_type(x)); });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
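To illustrate the kind of launch this wraps, here is a minimal standalone HIP sketch (assumptions: a hand-written kernel over raw hipMalloc buffers instead of MIGraphX's argument/nary machinery) that applies the single-precision rsqrtf intrinsic elementwise.

#include <hip/hip_runtime.h>
#include <cstdio>
#include <vector>

// Hand-rolled elementwise kernel; rsqrtf is the single-precision device intrinsic.
__global__ void rsqrt_kernel(const float* in, float* out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        out[i] = rsqrtf(in[i]);
}

int main()
{
    std::vector<float> h_in = {4.0f, 16.0f, 64.0f};
    int n = static_cast<int>(h_in.size());
    float* d_in;
    float* d_out;
    hipMalloc(&d_in, n * sizeof(float));
    hipMalloc(&d_out, n * sizeof(float));
    hipMemcpy(d_in, h_in.data(), n * sizeof(float), hipMemcpyHostToDevice);
    hipLaunchKernelGGL(rsqrt_kernel, dim3(1), dim3(64), 0, 0, d_in, d_out, n);
    std::vector<float> h_out(n);
    hipMemcpy(h_out.data(), d_out, n * sizeof(float), hipMemcpyDeviceToHost);
    for(float v : h_out)
        std::printf("%g\n", v); // expect 0.5 0.25 0.125
    hipFree(d_in);
    hipFree(d_out);
}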
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_RSQRT_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_RSQRT_HPP

#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void rsqrt(hipStream_t stream, const argument& result, const argument& arg);

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_RSQRT_HPP
#define MIGRAPHX_GUARD_RTGLIB_RSQRT_HPP

#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/rsqrt.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

// GPU operation wrapper that dispatches to device::rsqrt.
struct hip_rsqrt : unary_device<hip_rsqrt, device::rsqrt>
{
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
@@ -52,6 +52,7 @@
#include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/reduce_sum.hpp>
#include <migraphx/gpu/rsqrt.hpp>
#include <migraphx/gpu/sqrt.hpp>
#include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/pow.hpp>
@@ -107,6 +108,7 @@ struct miopen_apply
        add_generic_op<hip_div>("div");
        add_generic_op<hip_max>("max");
        add_generic_op<hip_min>("min");
        add_generic_op<hip_rsqrt>("rsqrt");
        add_generic_op<hip_pow>("pow");
        add_generic_op<hip_sqdiff>("sqdiff");
...
@@ -154,6 +154,7 @@ struct tf_parser
        add_generic_op("Identity", op::identity{});
        add_generic_op("Relu", op::relu{});
        add_generic_op("Relu6", op::clip{6.0, 0.0});
        add_generic_op("Rsqrt", op::rsqrt{});
        add_generic_op("Tanh", op::tanh{});
        add_generic_op("StopGradient", op::identity{});
...
@@ -1808,6 +1808,20 @@ TEST_CASE(reduce_sum_axis12)
    EXPECT(results_vector == gold);
}

TEST_CASE(rsqrt_test)
{
    migraphx::program p;
    migraphx::shape s{migraphx::shape::float_type, {3}};
    auto l = p.add_literal(migraphx::literal{s, {4.0, 16.0, 64.0}});
    p.add_instruction(migraphx::op::rsqrt{}, l);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({});
    std::vector<float> results_vector(3);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {0.5, 0.25, 0.125};
    EXPECT(migraphx::verify_range(results_vector, gold));
}

TEST_CASE(reduce_mean_axis1)
{
    migraphx::program p;
...
@@ -3570,6 +3570,19 @@ struct test_reduce_sum_half : verify_program<test_reduce_sum_half>
    };
};

struct test_rsqrt : verify_program<test_rsqrt>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        migraphx::shape s{migraphx::shape::float_type, {1, 3, 16, 16}};
        auto x = p.add_parameter("x", s);
        // Clip the input to [1.0, FLT_MAX] so rsqrt is never evaluated on non-positive values.
        auto l0 = p.add_instruction(migraphx::op::clip{std::numeric_limits<float>::max(), 1.0}, x);
        p.add_instruction(migraphx::op::rsqrt{}, l0);
        return p;
    };
};

struct test_reduce_mean : verify_program<test_reduce_mean>
{
    migraphx::program create_program() const
...
rsqrt_test.pb (new binary TensorFlow graph): a float Placeholder node "0" feeding a Rsqrt node "rsqrt".
@@ -367,6 +367,16 @@ TEST_CASE(reshape_test)
    EXPECT(p == prog);
}

TEST_CASE(rsqrt_test)
{
    migraphx::program p;
    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    p.add_instruction(migraphx::op::rsqrt{}, l0);
    auto prog = optimize_tf("rsqrt_test.pb", false);
    EXPECT(p == prog);
}

TEST_CASE(softmax_test)
{
    migraphx::program p;
...