Unverified commit 4563b8cd, authored by Paul Fultz II, committed by GitHub
Browse files

Fix handling of lowest values in pad operator (#514)



* Fix handling of lowest values in pad operator

* Formatting

* Formatting

* Formatting

* Add cpu test for lowest padding

* Add test for max
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 1a4ff504
#ifndef MIGRAPHX_GUARD_RTGLIB_CLAMP_HPP
#define MIGRAPHX_GUARD_RTGLIB_CLAMP_HPP
#include <migraphx/config.hpp>
#include <migraphx/float_equal.hpp>
#include <limits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Convert a pad value of type T to the element type U, saturating to the
/// representable range of U rather than overflowing.
///
/// @tparam U destination (tensor element) type
/// @tparam T source type of the pad value
/// @param x  pad value to convert
/// @return   numeric_limits<U>::lowest() / max() when x matches the lowest /
///           max of T (as judged by float_equal) or lies at or beyond the
///           corresponding bound of U; otherwise U(x).
///
/// NOTE(review): float_equal comes from <migraphx/float_equal.hpp>; it is
/// presumably a tolerant equality check -- confirm against that header.
template <class U, class T>
U pad_clamp(T x)
{
    // The extremes of the source type act as sentinels for "the extreme of
    // whatever the destination type is".
    if(float_equal(x, std::numeric_limits<T>::lowest()))
        return std::numeric_limits<U>::lowest();
    if(float_equal(x, std::numeric_limits<T>::max()))
        return std::numeric_limits<U>::max();

    // Inclusive comparisons on purpose: when U is an integer type whose bound
    // is not exactly representable in T (e.g. int32 max as float rounds up to
    // 2^31), a strict comparison lets x == rounded bound fall through to U(x),
    // which is an out-of-range floating-to-integer conversion. Returning the
    // bound for x == bound is otherwise identical to converting x.
    if(x <= std::numeric_limits<U>::lowest())
        return std::numeric_limits<U>::lowest();
    if(x >= std::numeric_limits<U>::max())
        return std::numeric_limits<U>::max();

    // Comparisons against NaN are false, so a NaN is not caught by the range
    // checks above and is converted as-is.
    return U(x);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp> #include <migraphx/par_dfor.hpp>
#include <migraphx/clamp.hpp>
#include <migraphx/cpu/gemm.hpp> #include <migraphx/cpu/gemm.hpp>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
...@@ -459,7 +460,10 @@ struct cpu_pad ...@@ -459,7 +460,10 @@ struct cpu_pad
{ {
assert(output_shape.standard()); assert(output_shape.standard());
argument result{output_shape}; argument result{output_shape};
result.visit([&](auto output) { std::fill(output.begin(), output.end(), op.value); }); result.visit([&](auto output) {
using type = typename decltype(output)::value_type;
std::fill(output.begin(), output.end(), pad_clamp<type>(op.value));
});
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(input.get_shape(), [&](const auto& idx) { shape_for_each(input.get_shape(), [&](const auto& idx) {
......
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/clamp.hpp>
#include <migraphx/gpu/device/nary.hpp> #include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/pad.hpp> #include <migraphx/gpu/device/pad.hpp>
#include <migraphx/gpu/device/tensor.hpp> #include <migraphx/gpu/device/tensor.hpp>
...@@ -18,11 +19,7 @@ pad(hipStream_t stream, argument result, argument arg1, float value, std::vector ...@@ -18,11 +19,7 @@ pad(hipStream_t stream, argument result, argument arg1, float value, std::vector
hip_visit_all(result, arg1)([&](auto output, auto input) { hip_visit_all(result, arg1)([&](auto output, auto input) {
using type = typename decltype(output)::value_type; using type = typename decltype(output)::value_type;
using hip_index = typename decltype(output)::hip_index; using hip_index = typename decltype(output)::hip_index;
type device_val = value; type device_val = pad_clamp<host_type<type>>(value);
if(float_equal(value, std::numeric_limits<float>::lowest()))
{
device_val = device_cast(std::numeric_limits<type>::lowest());
}
gs_launch(stream, result.get_shape().elements())( gs_launch(stream, result.get_shape().elements())(
[=](auto i) __device__ { output.data()[i] = device_val; }); [=](auto i) __device__ { output.data()[i] = device_val; });
......
...@@ -1993,6 +1993,36 @@ TEST_CASE(pad_test) ...@@ -1993,6 +1993,36 @@ TEST_CASE(pad_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
// CPU target: padding a 2x2 half literal with float lowest() must yield a
// 16-element result whose border is half lowest() around the original values.
TEST_CASE(pad_test_lowest_half)
{
    // Expected border value, expressed as float for comparison.
    const float lowest_half = std::numeric_limits<migraphx::half>::lowest();

    migraphx::program prog;
    migraphx::shape input_shape{migraphx::shape::half_type, {2, 2}};
    auto input = prog.add_literal(migraphx::literal{input_shape, {1, 2, 3, 4}});
    migraphx::op::pad pad_op{{1, 1, 1, 1}, std::numeric_limits<float>::lowest()};
    prog.add_instruction(pad_op, input);
    prog.compile(migraphx::cpu::target{});

    auto evaluated = prog.eval({}).back();
    std::vector<float> results_vector(16);
    evaluated.visit([&](auto view) { results_vector.assign(view.begin(), view.end()); });

    // clang-format off
    std::vector<float> gold{lowest_half, lowest_half, lowest_half, lowest_half,
                            lowest_half, 1,           2,           lowest_half,
                            lowest_half, 3,           4,           lowest_half,
                            lowest_half, lowest_half, lowest_half, lowest_half};
    // clang-format on
    EXPECT(migraphx::verify_range(results_vector, gold));
}
// CPU target: padding a 2x2 half literal with float max() must yield a
// 16-element result whose border is half max() around the original values.
TEST_CASE(pad_test_highest_half)
{
    // Expected border value, expressed as float for comparison.
    const float highest_half = std::numeric_limits<migraphx::half>::max();

    migraphx::program prog;
    migraphx::shape input_shape{migraphx::shape::half_type, {2, 2}};
    auto input = prog.add_literal(migraphx::literal{input_shape, {1, 2, 3, 4}});
    migraphx::op::pad pad_op{{1, 1, 1, 1}, std::numeric_limits<float>::max()};
    prog.add_instruction(pad_op, input);
    prog.compile(migraphx::cpu::target{});

    auto evaluated = prog.eval({}).back();
    std::vector<float> results_vector(16);
    evaluated.visit([&](auto view) { results_vector.assign(view.begin(), view.end()); });

    // clang-format off
    std::vector<float> gold{highest_half, highest_half, highest_half, highest_half,
                            highest_half, 1,            2,            highest_half,
                            highest_half, 3,            4,            highest_half,
                            highest_half, highest_half, highest_half, highest_half};
    // clang-format on
    EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(fp16_test) TEST_CASE(fp16_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <future> #include <future>
#include <thread> #include <thread>
#include <numeric>
#include <test.hpp> #include <test.hpp>
...@@ -2267,6 +2268,40 @@ struct test_pad_int8 : verify_program<test_pad_int8> ...@@ -2267,6 +2268,40 @@ struct test_pad_int8 : verify_program<test_pad_int8>
} }
}; };
// Verification program: pad a 2x2 half literal holding 0..3 using
// pads = {0, 0, 1, 1} and a pad value of float lowest().
struct test_pad_lowest : verify_program<test_pad_lowest>
{
    migraphx::program create_program() const
    {
        // Input values 0, 1, 2, 3 stored as half.
        std::vector<migraphx::half> values(4);
        std::iota(values.begin(), values.end(), 0);

        migraphx::program prog;
        migraphx::shape input_shape{migraphx::shape::half_type, {2, 2}};
        auto input = prog.add_literal(migraphx::literal{input_shape, values});

        migraphx::op::pad pad_op{};
        pad_op.pads  = {0, 0, 1, 1};
        pad_op.value = std::numeric_limits<float>::lowest();
        prog.add_instruction(pad_op, input);
        return prog;
    }
};
// Verification program: pad a 2x2 half literal holding 0..3 using
// pads = {0, 0, 1, 1} and a pad value of float max().
struct test_pad_highest : verify_program<test_pad_highest>
{
    migraphx::program create_program() const
    {
        // Input values 0, 1, 2, 3 stored as half.
        std::vector<migraphx::half> values(4);
        std::iota(values.begin(), values.end(), 0);

        migraphx::program prog;
        migraphx::shape input_shape{migraphx::shape::half_type, {2, 2}};
        auto input = prog.add_literal(migraphx::literal{input_shape, values});

        migraphx::op::pad pad_op{};
        pad_op.pads  = {0, 0, 1, 1};
        pad_op.value = std::numeric_limits<float>::max();
        prog.add_instruction(pad_op, input);
        return prog;
    }
};
struct test_pooling_autopad : verify_program<test_pooling_autopad> struct test_pooling_autopad : verify_program<test_pooling_autopad>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment