Commit f225e19a authored by Khalique's avatar Khalique
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into transpose_op

parents 8e2212ce 03f5c679
#ifndef MIGRAPHX_GUARD_OPERATORS_SQDIFF_HPP
#define MIGRAPHX_GUARD_OPERATORS_SQDIFF_HPP
#include <migraphx/op/binary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sqdiff : binary<sqdiff>
{
auto apply() const
{
return [](auto x, auto y) { return (x - y) * (x - y); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SQRT_HPP
#define MIGRAPHX_GUARD_OPERATORS_SQRT_HPP
#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sqrt : unary<sqrt>
{
auto apply() const
{
return [](auto x) { return std::sqrt(x); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -58,6 +58,8 @@ ...@@ -58,6 +58,8 @@
#include <migraphx/op/sin.hpp> #include <migraphx/op/sin.hpp>
#include <migraphx/op/slice.hpp> #include <migraphx/op/slice.hpp>
#include <migraphx/op/softmax.hpp> #include <migraphx/op/softmax.hpp>
#include <migraphx/op/sqrt.hpp>
#include <migraphx/op/sqdiff.hpp>
#include <migraphx/op/squeeze.hpp> #include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp> #include <migraphx/op/sub.hpp>
#include <migraphx/op/tanh.hpp> #include <migraphx/op/tanh.hpp>
......
...@@ -54,6 +54,7 @@ struct onnx_parser ...@@ -54,6 +54,7 @@ struct onnx_parser
add_generic_op("Asin", op::asin{}); add_generic_op("Asin", op::asin{});
add_generic_op("Acos", op::acos{}); add_generic_op("Acos", op::acos{});
add_generic_op("Atan", op::atan{}); add_generic_op("Atan", op::atan{});
add_generic_op("Sqrt", op::sqrt{});
add_binary_op("Add", op::add{}); add_binary_op("Add", op::add{});
add_binary_op("Div", op::div{}); add_binary_op("Div", op::div{});
......
...@@ -40,7 +40,9 @@ add_library(migraphx_device ...@@ -40,7 +40,9 @@ add_library(migraphx_device
device/div.cpp device/div.cpp
device/clip.cpp device/clip.cpp
device/reduce_sum.cpp device/reduce_sum.cpp
device/sqrt.cpp
device/reduce_mean.cpp device/reduce_mean.cpp
device/sqdiff.cpp
) )
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device) set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_clang_tidy_check(migraphx_device) rocm_clang_tidy_check(migraphx_device)
......
#include <migraphx/gpu/device/sqdiff.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sqdiff(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
nary(stream, result, arg1, arg2)([](auto x, auto y) { return (x - y) * (x - y); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/device/sqrt.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sqrt(hipStream_t stream, const argument& result, const argument& arg)
{
nary(stream, result, arg)([](auto x) { return ::sqrt(to_hip_type(x)); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SQDIFF_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SQDIFF_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sqdiff(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SQRT_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SQRT_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sqrt(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_SQDIFF_HPP
#define MIGRAPHX_GUARD_RTGLIB_SQDIFF_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/sqdiff.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_sqdiff : binary_device<hip_sqdiff, device::sqdiff>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_SQRT_HPP
#define MIGRAPHX_GUARD_RTGLIB_SQRT_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/sqrt.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_sqrt : unary_device<hip_sqrt, device::sqrt>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -52,7 +52,9 @@ ...@@ -52,7 +52,9 @@
#include <migraphx/gpu/convert.hpp> #include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/clip.hpp> #include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/reduce_sum.hpp> #include <migraphx/gpu/reduce_sum.hpp>
#include <migraphx/gpu/sqrt.hpp>
#include <migraphx/gpu/reduce_mean.hpp> #include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/sqdiff.hpp>
#include <utility> #include <utility>
#include <functional> #include <functional>
#include <algorithm> #include <algorithm>
...@@ -99,10 +101,12 @@ struct miopen_apply ...@@ -99,10 +101,12 @@ struct miopen_apply
add_generic_op<hip_asin>("asin"); add_generic_op<hip_asin>("asin");
add_generic_op<hip_acos>("acos"); add_generic_op<hip_acos>("acos");
add_generic_op<hip_atan>("atan"); add_generic_op<hip_atan>("atan");
add_generic_op<hip_sqrt>("sqrt");
add_generic_op<hip_mul>("mul"); add_generic_op<hip_mul>("mul");
add_generic_op<hip_div>("div"); add_generic_op<hip_div>("div");
add_generic_op<hip_max>("max"); add_generic_op<hip_max>("max");
add_generic_op<hip_min>("min"); add_generic_op<hip_min>("min");
add_generic_op<hip_sqdiff>("sqdiff");
add_extend_op<miopen_gemm, op::dot>("dot"); add_extend_op<miopen_gemm, op::dot>("dot");
add_extend_op<miopen_contiguous, op::contiguous>("contiguous"); add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
......
...@@ -159,6 +159,7 @@ struct tf_parser ...@@ -159,6 +159,7 @@ struct tf_parser
add_binary_op("Add", op::add{}); add_binary_op("Add", op::add{});
add_binary_op("Mul", op::mul{}); add_binary_op("Mul", op::mul{});
add_binary_op("SquaredDifference", op::sqdiff{});
add_binary_op("Sub", op::sub{}); add_binary_op("Sub", op::sub{});
add_mem_op("AvgPool", &tf_parser::parse_pooling); add_mem_op("AvgPool", &tf_parser::parse_pooling);
......
...@@ -542,6 +542,21 @@ TEST_CASE(erf_test) ...@@ -542,6 +542,21 @@ TEST_CASE(erf_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(sqrt_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {5}};
auto l = p.add_literal(
migraphx::literal{s, {1.02481645, 0.85643062, 0.03404123, 0.92791926, 0.10569184}});
p.add_instruction(migraphx::op::sqrt{}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.01233218, 0.92543537, 0.18450265, 0.96328566, 0.32510282};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(log_test) TEST_CASE(log_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -1853,4 +1868,19 @@ TEST_CASE(reduce_mean_int) ...@@ -1853,4 +1868,19 @@ TEST_CASE(reduce_mean_int)
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
TEST_CASE(sqdiff_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l1 = p.add_literal(migraphx::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(migraphx::literal{s, {1, 2, 3}});
p.add_instruction(migraphx::op::sqdiff{}, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {4, 4, 4};
EXPECT(migraphx::verify_range(results_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -255,6 +255,19 @@ struct test_erf : verify_program<test_erf> ...@@ -255,6 +255,19 @@ struct test_erf : verify_program<test_erf>
} }
}; };
struct test_sqrt : verify_program<test_sqrt>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}};
auto param = p.add_parameter("x", s);
auto param_abs = p.add_instruction(migraphx::op::abs{}, param);
p.add_instruction(migraphx::op::sqrt{}, param_abs);
return p;
}
};
struct test_log : verify_program<test_log> struct test_log : verify_program<test_log>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
...@@ -202,6 +202,16 @@ TEST_CASE(erf_test) ...@@ -202,6 +202,16 @@ TEST_CASE(erf_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(sqrt_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
p.add_instruction(migraphx::op::sqrt{}, input);
auto prog = migraphx::parse_onnx("sqrt_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(log_test) TEST_CASE(log_test)
{ {
migraphx::program p; migraphx::program p;
......
 sqrt-example:C
xy"Sqrt test_sqrtZ
x


b
y


B
:
0 Placeholder*
dtype0*
shape:
:
1 Placeholder*
dtype0*
shape:
*
sqdiffSquaredDifference01*
T0"
\ No newline at end of file
...@@ -364,6 +364,17 @@ TEST_CASE(softmax_test) ...@@ -364,6 +364,17 @@ TEST_CASE(softmax_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(sqdiff_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
p.add_instruction(migraphx::op::sqdiff{}, l0, l1);
auto prog = optimize_tf("sqdiff_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(squeeze_test) TEST_CASE(squeeze_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment