Unverified Commit 1039011a authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fix bugs in float_equal and add tests (#512)



* Fix bugs in float_equal and add tests

* Formatting
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 0928c6cb
......@@ -10,6 +10,7 @@
#include <migraphx/requires.hpp>
#include <migraphx/config.hpp>
#include <migraphx/type_traits.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -19,7 +20,7 @@ using common_type = typename std::common_type<Ts...>::type;
struct float_equal_fn
{
template <class T, MIGRAPHX_REQUIRES(std::is_floating_point<T>{})>
template <class T, MIGRAPHX_REQUIRES(is_floating_point<T>{})>
static bool apply(T x, T y)
{
return std::isfinite(x) and std::isfinite(y) and
......@@ -27,7 +28,7 @@ struct float_equal_fn
std::nextafter(x, std::numeric_limits<T>::max()) >= y;
}
template <class T, MIGRAPHX_REQUIRES(not std::is_floating_point<T>{})>
template <class T, MIGRAPHX_REQUIRES(not is_floating_point<T>{})>
static bool apply(T x, T y)
{
return x == y;
......
......@@ -36,4 +36,24 @@ using deduce = typename detail::deduce<T>::type;
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
namespace std {
template <class T>
struct common_type<migraphx::half, T> : std::common_type<float, T>
{
};
template <class T>
struct common_type<T, migraphx::half> : std::common_type<float, T>
{
};
template <>
struct common_type<migraphx::half, migraphx::half>
{
using type = migraphx::half;
};
} // namespace std
#endif
#include <migraphx/float_equal.hpp>
#include <migraphx/half.hpp>
#include "test.hpp"
#include <limits>
template <class T, class U>
struct float_equal_expression
{
T lhs;
U rhs;
operator bool() const { return migraphx::float_equal(lhs, rhs); }
bool operator not() const { return not bool(*this); }
friend std::ostream& operator<<(std::ostream& s, const float_equal_expression& self)
{
s << "migraphx::float_equal(" << self.lhs << ", " << self.rhs << ")";
return s;
}
};
template <class T, class U>
auto test_float_equal(T x, U y)
{
return test::make_lhs_expression(float_equal_expression<T, U>{x, y});
}
template <class T, class U>
void test_equality()
{
auto x1 = T(0.1);
auto x2 = U(0.0);
auto x3 = U(1.0);
EXPECT(test_float_equal(x1, x1));
EXPECT(test_float_equal(x2, x2));
EXPECT(test_float_equal(x3, x3));
EXPECT(not test_float_equal(x1, x2));
EXPECT(not test_float_equal(x2, x1));
EXPECT(not test_float_equal(x1, x3));
EXPECT(not test_float_equal(x3, x1));
EXPECT(not test_float_equal(x2, x3));
EXPECT(not test_float_equal(x3, x2));
}
TEST_CASE_REGISTER(test_equality<double, float>);
TEST_CASE_REGISTER(test_equality<double, int>);
TEST_CASE_REGISTER(test_equality<double, migraphx::half>);
TEST_CASE_REGISTER(test_equality<float, int>);
TEST_CASE_REGISTER(test_equality<migraphx::half, int>);
template <class T, class U>
void test_limits()
{
auto max1 = std::numeric_limits<T>::max();
auto max2 = std::numeric_limits<U>::max();
auto min1 = std::numeric_limits<T>::lowest();
auto min2 = std::numeric_limits<U>::lowest();
EXPECT(test_float_equal(max1, max1));
EXPECT(test_float_equal(max2, max2));
EXPECT(not test_float_equal(max1, max2));
EXPECT(not test_float_equal(max2, max1));
EXPECT(test_float_equal(min1, min1));
EXPECT(test_float_equal(min2, min2));
EXPECT(not test_float_equal(min1, min2));
EXPECT(not test_float_equal(min2, min1));
EXPECT(not test_float_equal(max1, min1));
EXPECT(not test_float_equal(min1, max1));
EXPECT(not test_float_equal(max2, min2));
EXPECT(not test_float_equal(min2, max2));
EXPECT(not test_float_equal(max1, min2));
EXPECT(not test_float_equal(min2, max1));
EXPECT(not test_float_equal(max2, min1));
EXPECT(not test_float_equal(min1, max2));
}
TEST_CASE_REGISTER(test_limits<double, float>);
TEST_CASE_REGISTER(test_limits<double, int>);
TEST_CASE_REGISTER(test_limits<double, migraphx::half>);
TEST_CASE_REGISTER(test_limits<float, int>);
TEST_CASE_REGISTER(test_limits<int, migraphx::half>);
TEST_CASE_REGISTER(test_limits<long, int>);
TEST_CASE_REGISTER(test_limits<long, char>);
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -11,13 +11,27 @@
#define MIGRAPHX_GUARD_TEST_TEST_HPP
namespace test {
// clang-format off
// NOLINTNEXTLINE
#define TEST_FOREACH_OPERATOR(m) \
m(==, equal) m(!=, not_equal) m(<=, less_than_equal) m(>=, greater_than_equal) m(<, less_than) \
m(>, greater_than)
#define TEST_FOREACH_BINARY_OPERATORS(m) \
m(==, equal) \
m(!=, not_equal) \
m(<=, less_than_equal) \
m(>=, greater_than_equal) \
m(<, less_than) \
m(>, greater_than) \
m(and, and_op) \
m(or, or_op)
// clang-format on
// clang-format off
// NOLINTNEXTLINE
#define TEST_FOREACH_UNARY_OPERATORS(m) \
m(not, not_op)
// clang-format on
// NOLINTNEXTLINE
#define TEST_EACH_OPERATOR_OBJECT(op, name) \
#define TEST_EACH_BINARY_OPERATOR_OBJECT(op, name) \
struct name \
{ \
static std::string as_string() { return #op; } \
......@@ -28,7 +42,30 @@ namespace test {
} \
};
TEST_FOREACH_OPERATOR(TEST_EACH_OPERATOR_OBJECT)
// NOLINTNEXTLINE
#define TEST_EACH_UNARY_OPERATOR_OBJECT(op, name) \
struct name \
{ \
static std::string as_string() { return #op; } \
template <class T> \
static decltype(auto) call(T&& x) \
{ \
return op x; \
} \
};
TEST_FOREACH_BINARY_OPERATORS(TEST_EACH_BINARY_OPERATOR_OBJECT)
TEST_FOREACH_UNARY_OPERATORS(TEST_EACH_UNARY_OPERATOR_OBJECT)
struct nop
{
static std::string as_string() { return ""; }
template <class T>
static decltype(auto) call(T&& x)
{
return x;
}
};
inline std::ostream& operator<<(std::ostream& s, std::nullptr_t)
{
......@@ -56,7 +93,7 @@ struct expression
friend std::ostream& operator<<(std::ostream& s, const expression& self)
{
s << " [ " << self.lhs << " " << Operator::as_string() << " " << self.rhs << " ]";
s << self.lhs << " " << Operator::as_string() << " " << self.rhs;
return s;
}
......@@ -70,7 +107,7 @@ expression<T, U, Operator> make_expression(T&& rhs, U&& lhs, Operator)
return {std::forward<T>(rhs), std::forward<U>(lhs)};
}
template <class T>
template <class T, class Operator = nop>
struct lhs_expression;
// TODO: Remove rvalue reference
......@@ -80,7 +117,13 @@ lhs_expression<T> make_lhs_expression(T&& lhs)
return lhs_expression<T>{std::forward<T>(lhs)};
}
template <class T>
template <class T, class Operator>
lhs_expression<T, Operator> make_lhs_expression(T&& lhs, Operator)
{
return lhs_expression<T, Operator>{std::forward<T>(lhs)};
}
template <class T, class Operator>
struct lhs_expression
{
T lhs;
......@@ -88,20 +131,27 @@ struct lhs_expression
friend std::ostream& operator<<(std::ostream& s, const lhs_expression& self)
{
s << self.lhs;
s << Operator::as_string() << " " << self.lhs;
return s;
}
T value() const { return lhs; }
decltype(auto) value() const { return Operator::call(lhs); }
// NOLINTNEXTLINE
#define TEST_LHS_OPERATOR(op, name) \
#define TEST_LHS_BINARY_OPERATOR(op, name) \
template <class U> \
auto operator op(const U& rhs) const \
{ \
return make_expression(lhs, rhs, name{}); /* NOLINT */ \
}
TEST_FOREACH_OPERATOR(TEST_LHS_OPERATOR)
TEST_FOREACH_BINARY_OPERATORS(TEST_LHS_BINARY_OPERATOR)
// NOLINTNEXTLINE
#define TEST_LHS_UNARY_OPERATOR(op, name) \
auto operator op() const { return make_lhs_expression(lhs, name{}); /* NOLINT */ }
TEST_FOREACH_UNARY_OPERATORS(TEST_LHS_UNARY_OPERATOR)
// NOLINTNEXTLINE
#define TEST_LHS_REOPERATOR(op) \
template <class U> \
......@@ -117,8 +167,6 @@ struct lhs_expression
TEST_LHS_REOPERATOR(&)
TEST_LHS_REOPERATOR(|)
TEST_LHS_REOPERATOR (^)
TEST_LHS_REOPERATOR(&&)
TEST_LHS_REOPERATOR(||)
};
struct capture
......@@ -128,16 +176,23 @@ struct capture
{
return make_lhs_expression(x);
}
template <class T, class Operator>
auto operator->*(const lhs_expression<T, Operator>& x) const
{
return x;
}
};
template <class T, class F>
void failed(T x, const char* msg, const char* func, const char* file, int line, F f)
{
if(!x.value())
if(!bool(x.value()))
{
std::cout << func << std::endl;
std::cout << file << ":" << line << ":" << std::endl;
std::cout << " FAILED: " << msg << " " << x << std::endl;
std::cout << " FAILED: " << msg << " "
<< "[ " << x << " ]" << std::endl;
f();
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment