Unverified Commit 752d3239 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into slice_op

parents 5cf8eb23 a960abad
...@@ -32,7 +32,9 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") ...@@ -32,7 +32,9 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
endif() endif()
endif() endif()
if(CMAKE_CXX_COMPILER MATCHES ".*hcc") include(CheckCXXCompilerFlag)
check_cxx_compiler_flag("--cuda-host-only -x hip" HAS_HIP)
if(HAS_HIP)
message(STATUS "Enable miopen backend") message(STATUS "Enable miopen backend")
set(MIGRAPHX_ENABLE_GPU On CACHE BOOL "") set(MIGRAPHX_ENABLE_GPU On CACHE BOOL "")
else() else()
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SIGN_HPP
#define MIGRAPHX_GUARD_OPERATORS_SIGN_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sign : unary<sign>
{
auto apply() const
{
return [](auto x) { return (x > 0 ? 1 : ((x < 0) ? -1 : 0)); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -56,6 +56,7 @@ ...@@ -56,6 +56,7 @@
#include <migraphx/op/rsqrt.hpp> #include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp> #include <migraphx/op/scalar.hpp>
#include <migraphx/op/sigmoid.hpp> #include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp> #include <migraphx/op/sinh.hpp>
#include <migraphx/op/sin.hpp> #include <migraphx/op/sin.hpp>
#include <migraphx/op/slice.hpp> #include <migraphx/op/slice.hpp>
......
...@@ -55,6 +55,7 @@ struct onnx_parser ...@@ -55,6 +55,7 @@ struct onnx_parser
add_generic_op("Acos", op::acos{}); add_generic_op("Acos", op::acos{});
add_generic_op("Atan", op::atan{}); add_generic_op("Atan", op::atan{});
add_generic_op("Sqrt", op::sqrt{}); add_generic_op("Sqrt", op::sqrt{});
add_generic_op("Sign", op::sign{});
add_binary_op("Add", op::add{}); add_binary_op("Add", op::add{});
add_binary_op("Div", op::div{}); add_binary_op("Div", op::div{});
......
...@@ -45,6 +45,7 @@ add_library(migraphx_device ...@@ -45,6 +45,7 @@ add_library(migraphx_device
device/reduce_mean.cpp device/reduce_mean.cpp
device/pow.cpp device/pow.cpp
device/sqdiff.cpp device/sqdiff.cpp
device/sign.cpp
) )
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device) set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_clang_tidy_check(migraphx_device) rocm_clang_tidy_check(migraphx_device)
......
#include <migraphx/gpu/device/sign.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sign(hipStream_t stream, const argument& result, const argument& arg)
{
nary(stream, result, arg)([](auto x) { return (x > 0 ? 1 : ((x < 0) ? -1 : 0)); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SIGN_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SIGN_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sign(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_SIGN_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIGN_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/sign.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_sign : unary_device<hip_sign, device::sign>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/gpu/erf.hpp> #include <migraphx/gpu/erf.hpp>
#include <migraphx/gpu/log.hpp> #include <migraphx/gpu/log.hpp>
#include <migraphx/gpu/sin.hpp> #include <migraphx/gpu/sin.hpp>
#include <migraphx/gpu/sign.hpp>
#include <migraphx/gpu/cos.hpp> #include <migraphx/gpu/cos.hpp>
#include <migraphx/gpu/tan.hpp> #include <migraphx/gpu/tan.hpp>
#include <migraphx/gpu/sinh.hpp> #include <migraphx/gpu/sinh.hpp>
...@@ -111,6 +112,7 @@ struct miopen_apply ...@@ -111,6 +112,7 @@ struct miopen_apply
add_generic_op<hip_rsqrt>("rsqrt"); add_generic_op<hip_rsqrt>("rsqrt");
add_generic_op<hip_pow>("pow"); add_generic_op<hip_pow>("pow");
add_generic_op<hip_sqdiff>("sqdiff"); add_generic_op<hip_sqdiff>("sqdiff");
add_generic_op<hip_sign>("sign");
add_extend_op<miopen_gemm, op::dot>("dot"); add_extend_op<miopen_gemm, op::dot>("dot");
add_extend_op<miopen_contiguous, op::contiguous>("contiguous"); add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
......
...@@ -557,6 +557,21 @@ TEST_CASE(sqrt_test) ...@@ -557,6 +557,21 @@ TEST_CASE(sqrt_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(sign_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {5}};
auto l = p.add_literal(
migraphx::literal{s, {1.02481645, 0.85643062, -0.03404123, -0.92791926, 0.0}});
p.add_instruction(migraphx::op::sign{}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 1.0, -1.0, -1.0, 0.0};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(log_test) TEST_CASE(log_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -268,6 +268,18 @@ struct test_sqrt : verify_program<test_sqrt> ...@@ -268,6 +268,18 @@ struct test_sqrt : verify_program<test_sqrt>
} }
}; };
struct test_sign : verify_program<test_sign>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}};
auto param = p.add_parameter("x", s);
p.add_instruction(migraphx::op::sign{}, param);
return p;
}
};
struct test_log : verify_program<test_log> struct test_log : verify_program<test_log>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
...@@ -212,6 +212,16 @@ TEST_CASE(sqrt_test) ...@@ -212,6 +212,16 @@ TEST_CASE(sqrt_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(sign_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}});
p.add_instruction(migraphx::op::sign{}, input);
auto prog = migraphx::parse_onnx("sign_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(log_test) TEST_CASE(log_test)
{ {
migraphx::program p; migraphx::program p;
......
 sign-example:C
xy"Sign test_signZ
x
 

b
y
 

B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment