Unverified Commit 1e0bbd78 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fix comparisons in migraphx::value class (#1146)

* Fix comparisons in migraphx::value class
parent c4b6469a
...@@ -362,7 +362,7 @@ struct value ...@@ -362,7 +362,7 @@ struct value
v(this->get_##vt()); \ v(this->get_##vt()); \
return; \ return; \
} }
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE) MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE)
MIGRAPHX_VALUE_GENERATE_CASE(array, ) MIGRAPHX_VALUE_GENERATE_CASE(array, )
MIGRAPHX_VALUE_GENERATE_CASE(object, ) MIGRAPHX_VALUE_GENERATE_CASE(object, )
} }
...@@ -434,6 +434,8 @@ struct value ...@@ -434,6 +434,8 @@ struct value
void debug_print(bool show_type = false) const; void debug_print(bool show_type = false) const;
type_t get_type() const;
private: private:
template <class T> template <class T>
std::vector<value> from_values(const T& r) std::vector<value> from_values(const T& r)
...@@ -443,7 +445,6 @@ struct value ...@@ -443,7 +445,6 @@ struct value
r.begin(), r.end(), std::back_inserter(v), [&](auto&& e) { return value(e); }); r.begin(), r.end(), std::back_inserter(v), [&](auto&& e) { return value(e); });
return v; return v;
} }
type_t get_type() const;
std::shared_ptr<value_base_impl> x; std::shared_ptr<value_base_impl> x;
std::string key; std::string key;
}; };
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/optional.hpp>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
...@@ -417,25 +418,12 @@ value value::with_key(const std::string& pkey) const ...@@ -417,25 +418,12 @@ value value::with_key(const std::string& pkey) const
return result; return result;
} }
template <class F, class T, class U, class Common = typename std::common_type<T, U>::type> template <class T>
auto compare_common_impl( const T& compare_decay(const T& x)
rank<1>, F f, const std::string& keyx, const T& x, const std::string& keyy, const U& y)
{
return f(std::forward_as_tuple(keyx, Common(x)), std::forward_as_tuple(keyy, Common(y)));
}
template <class F>
auto compare_common_impl(
rank<1>, F f, const std::string& keyx, std::nullptr_t, const std::string& keyy, std::nullptr_t)
{
return f(std::forward_as_tuple(keyx, 0), std::forward_as_tuple(keyy, 0));
}
template <class F, class T, class U>
auto compare_common_impl(rank<0>, F, const std::string&, const T&, const std::string&, const U&)
{ {
return false; return x;
} }
int compare_decay(std::nullptr_t) { return 0; }
template <class F> template <class F>
bool compare(const value& x, const value& y, F f) bool compare(const value& x, const value& y, F f)
...@@ -443,7 +431,11 @@ bool compare(const value& x, const value& y, F f) ...@@ -443,7 +431,11 @@ bool compare(const value& x, const value& y, F f)
bool result = false; bool result = false;
x.visit_value([&](auto&& a) { x.visit_value([&](auto&& a) {
y.visit_value([&](auto&& b) { y.visit_value([&](auto&& b) {
result = compare_common_impl(rank<1>{}, f, x.get_key(), a, y.get_key(), b); if constexpr(std::is_same<decltype(a), decltype(b)>{})
result = f(std::forward_as_tuple(x.get_key(), compare_decay(a)),
std::forward_as_tuple(y.get_key(), compare_decay(b)));
else
assert(false); // NOLINT
}); });
}); });
return result; return result;
...@@ -462,11 +454,16 @@ bool operator==(const value& x, const value& y) ...@@ -462,11 +454,16 @@ bool operator==(const value& x, const value& y)
return false; return false;
return compare(x, y, std::equal_to<>{}); return compare(x, y, std::equal_to<>{});
} }
bool operator!=(const value& x, const value& y) { return !(x == y); } bool operator!=(const value& x, const value& y) { return not(x == y); }
bool operator<(const value& x, const value& y) { return compare(x, y, std::less<>{}); } bool operator<(const value& x, const value& y)
bool operator<=(const value& x, const value& y) { return x == y or x < y; } {
if(x.get_type() != y.get_type())
return x.get_type() < y.get_type();
return compare(x, y, std::less<>{});
}
bool operator<=(const value& x, const value& y) { return not(x > y); }
bool operator>(const value& x, const value& y) { return y < x; } bool operator>(const value& x, const value& y) { return y < x; }
bool operator>=(const value& x, const value& y) { return x == y or x > y; } bool operator>=(const value& x, const value& y) { return not(x < y); }
void print_value(std::ostream& os, std::nullptr_t) { os << "null"; } void print_value(std::ostream& os, std::nullptr_t) { os << "null"; }
......
...@@ -68,9 +68,9 @@ struct nop ...@@ -68,9 +68,9 @@ struct nop
{ {
static std::string as_string() { return ""; } static std::string as_string() { return ""; }
template <class T> template <class T>
static decltype(auto) call(T&& x) static auto call(T&& x)
{ {
return x; return static_cast<T&&>(x);
} }
}; };
...@@ -113,6 +113,33 @@ inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v. ...@@ -113,6 +113,33 @@ inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v.
return s; return s;
} }
template <class T>
const T& get_value(const T& x)
{
return x;
}
template <class T, class Operator = nop>
struct lhs_expression;
template <class T>
lhs_expression<T> make_lhs_expression(T&& lhs);
template <class T, class Operator>
lhs_expression<T, Operator> make_lhs_expression(T&& lhs, Operator);
// NOLINTNEXTLINE
#define TEST_EXPR_BINARY_OPERATOR(op, name) \
template <class V> \
auto operator op(const V& rhs2) const \
{ \
return make_expression(*this, rhs2, name{}); /* NOLINT */ \
}
// NOLINTNEXTLINE
#define TEST_EXPR_UNARY_OPERATOR(op, name) \
auto operator op() const { return make_lhs_expression(lhs, name{}); /* NOLINT */ }
template <class T, class U, class Operator> template <class T, class U, class Operator>
struct expression struct expression
{ {
...@@ -125,7 +152,12 @@ struct expression ...@@ -125,7 +152,12 @@ struct expression
return s; return s;
} }
decltype(auto) value() const { return Operator::call(lhs, rhs); }; friend decltype(auto) get_value(const expression& e) { return e.value(); }
decltype(auto) value() const { return Operator::call(get_value(lhs), get_value(rhs)); };
TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR)
TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR)
}; };
// TODO: Remove rvalue references // TODO: Remove rvalue references
...@@ -135,9 +167,6 @@ expression<T, U, Operator> make_expression(T&& rhs, U&& lhs, Operator) ...@@ -135,9 +167,6 @@ expression<T, U, Operator> make_expression(T&& rhs, U&& lhs, Operator)
return {std::forward<T>(rhs), std::forward<U>(lhs)}; return {std::forward<T>(rhs), std::forward<U>(lhs)};
} }
template <class T, class Operator = nop>
struct lhs_expression;
// TODO: Remove rvalue reference // TODO: Remove rvalue reference
template <class T> template <class T>
lhs_expression<T> make_lhs_expression(T&& lhs) lhs_expression<T> make_lhs_expression(T&& lhs)
...@@ -166,22 +195,12 @@ struct lhs_expression ...@@ -166,22 +195,12 @@ struct lhs_expression
return s; return s;
} }
decltype(auto) value() const { return Operator::call(lhs); } friend decltype(auto) get_value(const lhs_expression& e) { return e.value(); }
// NOLINTNEXTLINE
#define TEST_LHS_BINARY_OPERATOR(op, name) \
template <class U> \
auto operator op(const U& rhs) const \
{ \
return make_expression(lhs, rhs, name{}); /* NOLINT */ \
}
TEST_FOREACH_BINARY_OPERATORS(TEST_LHS_BINARY_OPERATOR) decltype(auto) value() const { return Operator::call(get_value(lhs)); }
// NOLINTNEXTLINE TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR)
#define TEST_LHS_UNARY_OPERATOR(op, name) \ TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR)
auto operator op() const { return make_lhs_expression(lhs, name{}); /* NOLINT */ }
TEST_FOREACH_UNARY_OPERATORS(TEST_LHS_UNARY_OPERATOR)
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define TEST_LHS_REOPERATOR(op) \ #define TEST_LHS_REOPERATOR(op) \
...@@ -223,6 +242,13 @@ auto make_predicate(const std::string& msg, F f) ...@@ -223,6 +242,13 @@ auto make_predicate(const std::string& msg, F f)
return make_lhs_expression(predicate<F>{msg, f}, function{}); return make_lhs_expression(predicate<F>{msg, f}, function{});
} }
inline std::string as_string(bool x)
{
if(x)
return "true";
return "false";
}
template <class T> template <class T>
std::string as_string(const T& x) std::string as_string(const T& x)
{ {
...@@ -627,18 +653,21 @@ inline void run(int argc, const char* argv[]) ...@@ -627,18 +653,21 @@ inline void run(int argc, const char* argv[])
} // namespace test } // namespace test
// NOLINTNEXTLINE
#define TEST_CAPTURE(...) test::capture{}->*__VA_ARGS__
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define CHECK(...) \ #define CHECK(...) \
test::failed( \ test::failed( \
test::capture{}->*__VA_ARGS__, #__VA_ARGS__, __PRETTY_FUNCTION__, __FILE__, __LINE__, [] { \ test::capture{}->*__VA_ARGS__, #__VA_ARGS__, __PRETTY_FUNCTION__, __FILE__, __LINE__, [] { \
}) })
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define EXPECT(...) \ #define EXPECT(...) \
test::failed(test::capture{}->*__VA_ARGS__, \ test::failed(TEST_CAPTURE(__VA_ARGS__), \
#__VA_ARGS__, \ #__VA_ARGS__, \
__PRETTY_FUNCTION__, \ __PRETTY_FUNCTION__, \
__FILE__, \ __FILE__, \
__LINE__, \ __LINE__, \
&test::fail) &test::fail)
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define STATUS(...) EXPECT((__VA_ARGS__) == 0) #define STATUS(...) EXPECT((__VA_ARGS__) == 0)
......
...@@ -540,6 +540,14 @@ TEST_CASE(value_construct_object_string_mixed_value) ...@@ -540,6 +540,14 @@ TEST_CASE(value_construct_object_string_mixed_value)
EXPECT(v.at("two").get_int64() == 2); EXPECT(v.at("two").get_int64() == 2);
} }
template <class Expression>
auto compare_predicate(const Expression& e)
{
bool result = e.value();
return test::make_predicate(test::as_string(e) + " => " + test::as_string(result),
[=] { return result; });
}
TEST_CASE(value_compare) TEST_CASE(value_compare)
{ {
EXPECT(migraphx::value(1) == migraphx::value(1)); EXPECT(migraphx::value(1) == migraphx::value(1));
...@@ -553,6 +561,46 @@ TEST_CASE(value_compare) ...@@ -553,6 +561,46 @@ TEST_CASE(value_compare)
EXPECT(migraphx::value(2) > migraphx::value(1)); EXPECT(migraphx::value(2) > migraphx::value(1));
EXPECT(migraphx::value(2) >= migraphx::value(1)); EXPECT(migraphx::value(2) >= migraphx::value(1));
EXPECT(migraphx::value(1) >= migraphx::value(1)); EXPECT(migraphx::value(1) >= migraphx::value(1));
EXPECT(migraphx::value(1) != migraphx::value("1"));
EXPECT(migraphx::value(1) != migraphx::value());
}
// NOLINTNEXTLINE
#define MIGRAPHX_VALUE_TEST_COMPARE(...) compare_predicate(TEST_CAPTURE(__VA_ARGS__))
// NOLINTNEXTLINE
#define EXPECT_TOTALLY_ORDERED_IMPL(_, x, y) \
EXPECT(_(x <= y) or _(x >= y)); \
EXPECT(_(x < y) or _(x > y) or _(x == y)); \
EXPECT((_(x < y) or _(x > y)) == _(x != y)); \
EXPECT(_(x < y) == _(y > x)); \
EXPECT(_(x <= y) == _(y >= x)); \
EXPECT(_(x < y) != _(x >= y)); \
EXPECT(_(x > y) != _(x <= y)); \
EXPECT(_(x == y) != _(x != y))
// NOLINTNEXTLINE
#define EXPECT_TOTALLY_ORDERED(x, y) \
EXPECT_TOTALLY_ORDERED_IMPL(MIGRAPHX_VALUE_TEST_COMPARE, x, y); \
EXPECT_TOTALLY_ORDERED_IMPL(MIGRAPHX_VALUE_TEST_COMPARE, y, x)
// NOLINTNEXTLINE(readability-function-size)
TEST_CASE(value_compare_ordered)
{
EXPECT_TOTALLY_ORDERED(migraphx::value(), migraphx::value());
EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value(1));
EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value(2));
EXPECT_TOTALLY_ORDERED(migraphx::value("key", 1), migraphx::value("key", 1));
EXPECT_TOTALLY_ORDERED(migraphx::value("key1", 1), migraphx::value("key2", 2));
EXPECT_TOTALLY_ORDERED(migraphx::value("key", 1), migraphx::value("key", 2));
EXPECT_TOTALLY_ORDERED(migraphx::value("key1", 1), migraphx::value("key2", 2));
EXPECT_TOTALLY_ORDERED(migraphx::value("key", 1), migraphx::value("key", "2"));
EXPECT_TOTALLY_ORDERED(migraphx::value("key1", 1), migraphx::value("key2", "2"));
EXPECT_TOTALLY_ORDERED(migraphx::value(std::int64_t{1}), migraphx::value(std::uint64_t{1}));
EXPECT_TOTALLY_ORDERED(migraphx::value(std::int64_t{1}), migraphx::value(std::uint64_t{2}));
EXPECT_TOTALLY_ORDERED(migraphx::value(std::int64_t{2}), migraphx::value(std::uint64_t{1}));
EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value("1"));
EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value());
} }
TEST_CASE(value_to_from_string) TEST_CASE(value_to_from_string)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment