Unverified Commit 1e0bbd78 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fix comparisons in migraphx::value class (#1146)

* Fix comparisons in migraphx::value class
parent c4b6469a
......@@ -362,7 +362,7 @@ struct value
v(this->get_##vt()); \
return; \
}
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE)
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE)
MIGRAPHX_VALUE_GENERATE_CASE(array, )
MIGRAPHX_VALUE_GENERATE_CASE(object, )
}
......@@ -434,6 +434,8 @@ struct value
void debug_print(bool show_type = false) const;
type_t get_type() const;
private:
template <class T>
std::vector<value> from_values(const T& r)
......@@ -443,7 +445,6 @@ struct value
r.begin(), r.end(), std::back_inserter(v), [&](auto&& e) { return value(e); });
return v;
}
type_t get_type() const;
std::shared_ptr<value_base_impl> x;
std::string key;
};
......
......@@ -4,6 +4,7 @@
#include <migraphx/errors.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp>
#include <migraphx/optional.hpp>
#include <unordered_map>
#include <utility>
......@@ -417,25 +418,12 @@ value value::with_key(const std::string& pkey) const
return result;
}
template <class F, class T, class U, class Common = typename std::common_type<T, U>::type>
auto compare_common_impl(
rank<1>, F f, const std::string& keyx, const T& x, const std::string& keyy, const U& y)
{
return f(std::forward_as_tuple(keyx, Common(x)), std::forward_as_tuple(keyy, Common(y)));
}
template <class F>
auto compare_common_impl(
rank<1>, F f, const std::string& keyx, std::nullptr_t, const std::string& keyy, std::nullptr_t)
{
return f(std::forward_as_tuple(keyx, 0), std::forward_as_tuple(keyy, 0));
}
template <class F, class T, class U>
auto compare_common_impl(rank<0>, F, const std::string&, const T&, const std::string&, const U&)
template <class T>
const T& compare_decay(const T& x)
{
return false;
return x;
}
int compare_decay(std::nullptr_t) { return 0; }
template <class F>
bool compare(const value& x, const value& y, F f)
......@@ -443,7 +431,11 @@ bool compare(const value& x, const value& y, F f)
bool result = false;
x.visit_value([&](auto&& a) {
y.visit_value([&](auto&& b) {
result = compare_common_impl(rank<1>{}, f, x.get_key(), a, y.get_key(), b);
if constexpr(std::is_same<decltype(a), decltype(b)>{})
result = f(std::forward_as_tuple(x.get_key(), compare_decay(a)),
std::forward_as_tuple(y.get_key(), compare_decay(b)));
else
assert(false); // NOLINT
});
});
return result;
......@@ -462,11 +454,16 @@ bool operator==(const value& x, const value& y)
return false;
return compare(x, y, std::equal_to<>{});
}
bool operator!=(const value& x, const value& y) { return !(x == y); }
bool operator<(const value& x, const value& y) { return compare(x, y, std::less<>{}); }
bool operator<=(const value& x, const value& y) { return x == y or x < y; }
bool operator!=(const value& x, const value& y) { return not(x == y); }
bool operator<(const value& x, const value& y)
{
if(x.get_type() != y.get_type())
return x.get_type() < y.get_type();
return compare(x, y, std::less<>{});
}
bool operator<=(const value& x, const value& y) { return not(x > y); }
bool operator>(const value& x, const value& y) { return y < x; }
bool operator>=(const value& x, const value& y) { return x == y or x > y; }
bool operator>=(const value& x, const value& y) { return not(x < y); }
void print_value(std::ostream& os, std::nullptr_t) { os << "null"; }
......
......@@ -68,9 +68,9 @@ struct nop
{
static std::string as_string() { return ""; }
template <class T>
static decltype(auto) call(T&& x)
static auto call(T&& x)
{
return x;
return static_cast<T&&>(x);
}
};
......@@ -113,6 +113,33 @@ inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v.
return s;
}
template <class T>
const T& get_value(const T& x)
{
return x;
}
template <class T, class Operator = nop>
struct lhs_expression;
template <class T>
lhs_expression<T> make_lhs_expression(T&& lhs);
template <class T, class Operator>
lhs_expression<T, Operator> make_lhs_expression(T&& lhs, Operator);
// NOLINTNEXTLINE
#define TEST_EXPR_BINARY_OPERATOR(op, name) \
template <class V> \
auto operator op(const V& rhs2) const \
{ \
return make_expression(*this, rhs2, name{}); /* NOLINT */ \
}
// NOLINTNEXTLINE
#define TEST_EXPR_UNARY_OPERATOR(op, name) \
auto operator op() const { return make_lhs_expression(lhs, name{}); /* NOLINT */ }
template <class T, class U, class Operator>
struct expression
{
......@@ -125,7 +152,12 @@ struct expression
return s;
}
decltype(auto) value() const { return Operator::call(lhs, rhs); };
friend decltype(auto) get_value(const expression& e) { return e.value(); }
decltype(auto) value() const { return Operator::call(get_value(lhs), get_value(rhs)); };
TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR)
TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR)
};
// TODO: Remove rvalue references
......@@ -135,9 +167,6 @@ expression<T, U, Operator> make_expression(T&& rhs, U&& lhs, Operator)
return {std::forward<T>(rhs), std::forward<U>(lhs)};
}
template <class T, class Operator = nop>
struct lhs_expression;
// TODO: Remove rvalue reference
template <class T>
lhs_expression<T> make_lhs_expression(T&& lhs)
......@@ -166,22 +195,12 @@ struct lhs_expression
return s;
}
decltype(auto) value() const { return Operator::call(lhs); }
// NOLINTNEXTLINE
#define TEST_LHS_BINARY_OPERATOR(op, name) \
template <class U> \
auto operator op(const U& rhs) const \
{ \
return make_expression(lhs, rhs, name{}); /* NOLINT */ \
}
friend decltype(auto) get_value(const lhs_expression& e) { return e.value(); }
TEST_FOREACH_BINARY_OPERATORS(TEST_LHS_BINARY_OPERATOR)
decltype(auto) value() const { return Operator::call(get_value(lhs)); }
// NOLINTNEXTLINE
#define TEST_LHS_UNARY_OPERATOR(op, name) \
auto operator op() const { return make_lhs_expression(lhs, name{}); /* NOLINT */ }
TEST_FOREACH_UNARY_OPERATORS(TEST_LHS_UNARY_OPERATOR)
TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR)
TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR)
// NOLINTNEXTLINE
#define TEST_LHS_REOPERATOR(op) \
......@@ -223,6 +242,13 @@ auto make_predicate(const std::string& msg, F f)
return make_lhs_expression(predicate<F>{msg, f}, function{});
}
inline std::string as_string(bool x)
{
if(x)
return "true";
return "false";
}
template <class T>
std::string as_string(const T& x)
{
......@@ -627,18 +653,21 @@ inline void run(int argc, const char* argv[])
} // namespace test
// NOLINTNEXTLINE
#define TEST_CAPTURE(...) test::capture{}->*__VA_ARGS__
// NOLINTNEXTLINE
#define CHECK(...) \
test::failed( \
test::capture{}->*__VA_ARGS__, #__VA_ARGS__, __PRETTY_FUNCTION__, __FILE__, __LINE__, [] { \
})
// NOLINTNEXTLINE
#define EXPECT(...) \
test::failed(test::capture{}->*__VA_ARGS__, \
#__VA_ARGS__, \
__PRETTY_FUNCTION__, \
__FILE__, \
__LINE__, \
#define EXPECT(...) \
test::failed(TEST_CAPTURE(__VA_ARGS__), \
#__VA_ARGS__, \
__PRETTY_FUNCTION__, \
__FILE__, \
__LINE__, \
&test::fail)
// NOLINTNEXTLINE
#define STATUS(...) EXPECT((__VA_ARGS__) == 0)
......
......@@ -540,6 +540,14 @@ TEST_CASE(value_construct_object_string_mixed_value)
EXPECT(v.at("two").get_int64() == 2);
}
template <class Expression>
auto compare_predicate(const Expression& e)
{
bool result = e.value();
return test::make_predicate(test::as_string(e) + " => " + test::as_string(result),
[=] { return result; });
}
TEST_CASE(value_compare)
{
EXPECT(migraphx::value(1) == migraphx::value(1));
......@@ -553,6 +561,46 @@ TEST_CASE(value_compare)
EXPECT(migraphx::value(2) > migraphx::value(1));
EXPECT(migraphx::value(2) >= migraphx::value(1));
EXPECT(migraphx::value(1) >= migraphx::value(1));
EXPECT(migraphx::value(1) != migraphx::value("1"));
EXPECT(migraphx::value(1) != migraphx::value());
}
// NOLINTNEXTLINE
#define MIGRAPHX_VALUE_TEST_COMPARE(...) compare_predicate(TEST_CAPTURE(__VA_ARGS__))
// NOLINTNEXTLINE
#define EXPECT_TOTALLY_ORDERED_IMPL(_, x, y) \
EXPECT(_(x <= y) or _(x >= y)); \
EXPECT(_(x < y) or _(x > y) or _(x == y)); \
EXPECT((_(x < y) or _(x > y)) == _(x != y)); \
EXPECT(_(x < y) == _(y > x)); \
EXPECT(_(x <= y) == _(y >= x)); \
EXPECT(_(x < y) != _(x >= y)); \
EXPECT(_(x > y) != _(x <= y)); \
EXPECT(_(x == y) != _(x != y))
// NOLINTNEXTLINE
#define EXPECT_TOTALLY_ORDERED(x, y) \
EXPECT_TOTALLY_ORDERED_IMPL(MIGRAPHX_VALUE_TEST_COMPARE, x, y); \
EXPECT_TOTALLY_ORDERED_IMPL(MIGRAPHX_VALUE_TEST_COMPARE, y, x)
// NOLINTNEXTLINE(readability-function-size)
TEST_CASE(value_compare_ordered)
{
EXPECT_TOTALLY_ORDERED(migraphx::value(), migraphx::value());
EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value(1));
EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value(2));
EXPECT_TOTALLY_ORDERED(migraphx::value("key", 1), migraphx::value("key", 1));
EXPECT_TOTALLY_ORDERED(migraphx::value("key1", 1), migraphx::value("key2", 2));
EXPECT_TOTALLY_ORDERED(migraphx::value("key", 1), migraphx::value("key", 2));
EXPECT_TOTALLY_ORDERED(migraphx::value("key1", 1), migraphx::value("key2", 2));
EXPECT_TOTALLY_ORDERED(migraphx::value("key", 1), migraphx::value("key", "2"));
EXPECT_TOTALLY_ORDERED(migraphx::value("key1", 1), migraphx::value("key2", "2"));
EXPECT_TOTALLY_ORDERED(migraphx::value(std::int64_t{1}), migraphx::value(std::uint64_t{1}));
EXPECT_TOTALLY_ORDERED(migraphx::value(std::int64_t{1}), migraphx::value(std::uint64_t{2}));
EXPECT_TOTALLY_ORDERED(migraphx::value(std::int64_t{2}), migraphx::value(std::uint64_t{1}));
EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value("1"));
EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value());
}
TEST_CASE(value_to_from_string)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment