Commit 86094d23 authored by Khalique's avatar Khalique
Browse files

manual merge

parents 6816a475 524c86ff
#ifndef MIGRAPHX_GUARD_OPERATORS_SQDIFF_HPP
#define MIGRAPHX_GUARD_OPERATORS_SQDIFF_HPP
#include <migraphx/op/binary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sqdiff : binary<sqdiff>
{
auto apply() const
{
return [](auto x, auto y) { return (x - y) * (x - y); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -59,6 +59,7 @@
#include <migraphx/op/sin.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/softmax.hpp>
#include <migraphx/op/sqdiff.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/tanh.hpp>
......
......@@ -42,6 +42,7 @@ add_library(migraphx_device
device/reduce_sum.cpp
device/rsqrt.cpp
device/reduce_mean.cpp
device/sqdiff.cpp
)
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_clang_tidy_check(migraphx_device)
......
#include <migraphx/gpu/device/sqdiff.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sqdiff(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
nary(stream, result, arg1, arg2)([](auto x, auto y) { return (x - y) * (x - y); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SQDIFF_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SQDIFF_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sqdiff(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_SQDIFF_HPP
#define MIGRAPHX_GUARD_RTGLIB_SQDIFF_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/sqdiff.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_sqdiff : binary_device<hip_sqdiff, device::sqdiff>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -54,6 +54,7 @@
#include <migraphx/gpu/reduce_sum.hpp>
#include <migraphx/gpu/rsqrt.hpp>
#include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/sqdiff.hpp>
#include <utility>
#include <functional>
#include <algorithm>
......@@ -105,6 +106,7 @@ struct miopen_apply
add_generic_op<hip_max>("max");
add_generic_op<hip_min>("min");
add_generic_op<hip_rsqrt>("rsqrt");
add_generic_op<hip_sqdiff>("sqdiff");
add_extend_op<miopen_gemm, op::dot>("dot");
add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
......
......@@ -160,6 +160,7 @@ struct tf_parser
add_binary_op("Add", op::add{});
add_binary_op("Mul", op::mul{});
add_binary_op("SquaredDifference", op::sqdiff{});
add_binary_op("Sub", op::sub{});
add_mem_op("AvgPool", &tf_parser::parse_pooling);
......
......@@ -1867,4 +1867,19 @@ TEST_CASE(reduce_mean_int)
EXPECT(results_vector == gold);
}
TEST_CASE(sqdiff_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l1 = p.add_literal(migraphx::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(migraphx::literal{s, {1, 2, 3}});
p.add_instruction(migraphx::op::sqdiff{}, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {4, 4, 4};
EXPECT(migraphx::verify_range(results_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
:
0 Placeholder*
dtype0*
shape:
:
1 Placeholder*
dtype0*
shape:
*
sqdiffSquaredDifference01*
T0"
\ No newline at end of file
......@@ -374,6 +374,17 @@ TEST_CASE(softmax_test)
EXPECT(p == prog);
}
TEST_CASE(sqdiff_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
p.add_instruction(migraphx::op::sqdiff{}, l0, l1);
auto prog = optimize_tf("sqdiff_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(squeeze_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment