"src/libtorchaudio/CMakeLists.txt" did not exist on "72b712a1eef4f3ba292b8712e2acf15519d61378"
Unverified Commit 4563b8cd authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fix handling of lowest values in pad operator (#514)



* Fix handling of lowest values in pad operator

* Formatting

* Formatting

* Formatting

* Add cpu test for lowest padding

* Add test for max
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 1a4ff504
#ifndef MIGRAPHX_GUARD_RTGLIB_CLAMP_HPP
#define MIGRAPHX_GUARD_RTGLIB_CLAMP_HPP

// <limits> is required for std::numeric_limits; a standalone header must not
// rely on transitive includes to bring it in.
#include <limits>

#include <migraphx/config.hpp>
#include <migraphx/float_equal.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/// Convert `x` from source type T to destination type U, saturating to U's
/// representable range. The sentinel extremes of T (lowest()/max(), commonly
/// used as "pad with -inf/+inf"-style fill values) are mapped exactly onto
/// U's lowest()/max() rather than merely range-clamped, so e.g. a float
/// lowest() pad value becomes half lowest() for a half-precision tensor.
template <class U, class T>
U pad_clamp(T x)
{
    // float_equal provides a warning-free exact comparison for the sentinel
    // checks (plain == on floats trips -Wfloat-equal).
    if(float_equal(x, std::numeric_limits<T>::lowest()))
        return std::numeric_limits<U>::lowest();
    if(float_equal(x, std::numeric_limits<T>::max()))
        return std::numeric_limits<U>::max();
    // Ordinary values: saturate to U's range, then convert.
    return (x < std::numeric_limits<U>::lowest())
               ? std::numeric_limits<U>::lowest()
               : (std::numeric_limits<U>::max() < x) ? std::numeric_limits<U>::max() : U(x);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -21,6 +21,7 @@
#include <migraphx/shape_for_each.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/clamp.hpp>
#include <migraphx/cpu/gemm.hpp>
#include <unordered_map>
#include <utility>
......@@ -459,7 +460,10 @@ struct cpu_pad
{
assert(output_shape.standard());
argument result{output_shape};
result.visit([&](auto output) { std::fill(output.begin(), output.end(), op.value); });
result.visit([&](auto output) {
using type = typename decltype(output)::value_type;
std::fill(output.begin(), output.end(), pad_clamp<type>(op.value));
});
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(input.get_shape(), [&](const auto& idx) {
......
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/clamp.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/pad.hpp>
#include <migraphx/gpu/device/tensor.hpp>
......@@ -18,11 +19,7 @@ pad(hipStream_t stream, argument result, argument arg1, float value, std::vector
hip_visit_all(result, arg1)([&](auto output, auto input) {
using type = typename decltype(output)::value_type;
using hip_index = typename decltype(output)::hip_index;
type device_val = value;
if(float_equal(value, std::numeric_limits<float>::lowest()))
{
device_val = device_cast(std::numeric_limits<type>::lowest());
}
type device_val = pad_clamp<host_type<type>>(value);
gs_launch(stream, result.get_shape().elements())(
[=](auto i) __device__ { output.data()[i] = device_val; });
......
......@@ -1993,6 +1993,36 @@ TEST_CASE(pad_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(pad_test_lowest_half)
{
    // A float lowest() pad value applied to a half tensor must fill the
    // border with half lowest(), not an overflowed conversion of the float.
    migraphx::program p;
    migraphx::shape in_shape{migraphx::shape::half_type, {2, 2}};
    auto input = p.add_literal(migraphx::literal{in_shape, {1, 2, 3, 4}});
    p.add_instruction(migraphx::op::pad{{1, 1, 1, 1}, std::numeric_limits<float>::lowest()},
                      input);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({}).back();
    std::vector<float> out(16);
    result.visit([&](auto output) { out.assign(output.begin(), output.end()); });
    // Expected: a 4x4 tensor whose border is half lowest() around the 2x2 input.
    const float pad_val = std::numeric_limits<migraphx::half>::lowest();
    std::vector<float> gold{pad_val, pad_val, pad_val, pad_val,
                            pad_val, 1,       2,       pad_val,
                            pad_val, 3,       4,       pad_val,
                            pad_val, pad_val, pad_val, pad_val};
    EXPECT(migraphx::verify_range(out, gold));
}
TEST_CASE(pad_test_highest_half)
{
    // A float max() pad value applied to a half tensor must fill the border
    // with half max(), mirroring the lowest()-value handling.
    migraphx::program p;
    migraphx::shape in_shape{migraphx::shape::half_type, {2, 2}};
    auto input = p.add_literal(migraphx::literal{in_shape, {1, 2, 3, 4}});
    p.add_instruction(migraphx::op::pad{{1, 1, 1, 1}, std::numeric_limits<float>::max()},
                      input);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({}).back();
    std::vector<float> out(16);
    result.visit([&](auto output) { out.assign(output.begin(), output.end()); });
    // Expected: a 4x4 tensor whose border is half max() around the 2x2 input.
    const float pad_val = std::numeric_limits<migraphx::half>::max();
    std::vector<float> gold{pad_val, pad_val, pad_val, pad_val,
                            pad_val, 1,       2,       pad_val,
                            pad_val, 3,       4,       pad_val,
                            pad_val, pad_val, pad_val, pad_val};
    EXPECT(migraphx::verify_range(out, gold));
}
TEST_CASE(fp16_test)
{
migraphx::program p;
......
......@@ -18,6 +18,7 @@
#include <future>
#include <thread>
#include <numeric>
#include <test.hpp>
......@@ -2267,6 +2268,40 @@ struct test_pad_int8 : verify_program<test_pad_int8>
}
};
// Verify (gpu vs. reference) that a float lowest() pad value is converted to
// half lowest() when padding a half-precision tensor.
struct test_pad_lowest : verify_program<test_pad_lowest>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        migraphx::shape input_shape{migraphx::shape::half_type, {2, 2}};
        // Input values 0,1,2,3 as half.
        std::vector<migraphx::half> input_data(4);
        std::iota(input_data.begin(), input_data.end(), 0);
        auto input = p.add_literal(migraphx::literal{input_shape, input_data});
        migraphx::op::pad pad_op{};
        pad_op.pads  = {0, 0, 1, 1};
        pad_op.value = std::numeric_limits<float>::lowest();
        p.add_instruction(pad_op, input);
        return p;
    }
};
// Verify (gpu vs. reference) that a float max() pad value is converted to
// half max() when padding a half-precision tensor.
struct test_pad_highest : verify_program<test_pad_highest>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        migraphx::shape input_shape{migraphx::shape::half_type, {2, 2}};
        // Input values 0,1,2,3 as half.
        std::vector<migraphx::half> input_data(4);
        std::iota(input_data.begin(), input_data.end(), 0);
        auto input = p.add_literal(migraphx::literal{input_shape, input_data});
        migraphx::op::pad pad_op{};
        pad_op.pads  = {0, 0, 1, 1};
        pad_op.value = std::numeric_limits<float>::max();
        p.add_instruction(pad_op, input);
        return p;
    }
};
struct test_pooling_autopad : verify_program<test_pooling_autopad>
{
migraphx::program create_program() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.