Unverified Commit c73c0dae authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Improve handling of string literals in value class (#1141)

* Handle string literal in construction
* Improve get_default with vector
parent 251cdd74
...@@ -178,6 +178,7 @@ struct value ...@@ -178,6 +178,7 @@ struct value
value(std::nullptr_t); value(std::nullptr_t);
value(const char* i); value(const char* i);
value(const std::string& pkey, const char* i);
#define MIGRAPHX_VALUE_GENERATE_DECL_METHODS(vt, cpp_type) \ #define MIGRAPHX_VALUE_GENERATE_DECL_METHODS(vt, cpp_type) \
value(cpp_type i); \ value(cpp_type i); \
...@@ -188,6 +189,12 @@ struct value ...@@ -188,6 +189,12 @@ struct value
const cpp_type* if_##vt() const; const cpp_type* if_##vt() const;
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DECL_METHODS) MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DECL_METHODS)
template <class T>
using literal_to_string = std::conditional_t<(std::is_convertible<T, const char*>{} and
std::is_convertible<T, std::string>{}),
std::string,
T>;
template <class T> template <class T>
using pick_numeric = std::conditional_t< using pick_numeric = std::conditional_t<
std::is_floating_point<T>{}, std::is_floating_point<T>{},
...@@ -246,6 +253,7 @@ struct value ...@@ -246,6 +253,7 @@ struct value
return *this = from_values(rhs); // NOLINT return *this = from_values(rhs); // NOLINT
} }
value& operator=(const char* c);
value& operator=(std::nullptr_t); value& operator=(std::nullptr_t);
value& operator=(const std::initializer_list<value>& i); value& operator=(const std::initializer_list<value>& i);
...@@ -370,11 +378,11 @@ struct value ...@@ -370,11 +378,11 @@ struct value
} }
template <class To> template <class To>
To value_or(const To& default_value) const literal_to_string<To> value_or(const To& default_value) const
{ {
if(this->is_null()) if(this->is_null())
return default_value; return default_value;
return to<To>(); return to<literal_to_string<To>>();
} }
template <class To> template <class To>
...@@ -390,12 +398,12 @@ struct value ...@@ -390,12 +398,12 @@ struct value
} }
template <class To> template <class To>
To get(const std::string& pkey, const To& default_value) const literal_to_string<To> get(const std::string& pkey, const To& default_value) const
{ {
const auto* v = find(pkey); const auto* v = find(pkey);
if(v == this->end()) if(v == this->end())
return default_value; return default_value;
return v->to<To>(); return v->to<literal_to_string<To>>();
} }
template <class To> template <class To>
...@@ -408,10 +416,11 @@ struct value ...@@ -408,10 +416,11 @@ struct value
} }
template <class To> template <class To>
std::vector<To> get(const std::string& pkey, std::vector<literal_to_string<To>> get(const std::string& pkey,
const std::initializer_list<To>& default_value) const const std::initializer_list<To>& default_value) const
{ {
return get<std::vector<To>>(pkey, default_value); return get(pkey,
std::vector<literal_to_string<To>>{default_value.begin(), default_value.end()});
} }
friend bool operator==(const value& x, const value& y); friend bool operator==(const value& x, const value& y);
......
...@@ -213,7 +213,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -213,7 +213,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::shape>(m, "shape") py::class_<migraphx::shape>(m, "shape")
.def(py::init([](py::kwargs kwargs) { .def(py::init([](py::kwargs kwargs) {
auto v = migraphx::to_value(kwargs); auto v = migraphx::to_value(kwargs);
auto t = migraphx::shape::parse_type(v.get("type", std::string{"float"})); auto t = migraphx::shape::parse_type(v.get("type", "float"));
auto lens = v.get<std::size_t>("lens", {1}); auto lens = v.get<std::size_t>("lens", {1});
if(v.contains("strides")) if(v.contains("strides"))
return migraphx::shape(t, lens, v.at("strides").to_vector<std::size_t>()); return migraphx::shape(t, lens, v.at("strides").to_vector<std::size_t>());
......
...@@ -138,6 +138,7 @@ value::value(const std::string& pkey, const value& rhs) ...@@ -138,6 +138,7 @@ value::value(const std::string& pkey, const value& rhs)
{ {
} }
value::value(const std::string& pkey, const char* i) : value(pkey, std::string(i)) {}
value::value(const char* i) : value(std::string(i)) {} value::value(const char* i) : value(std::string(i)) {}
#define MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS(vt, cpp_type) \ #define MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS(vt, cpp_type) \
...@@ -161,6 +162,12 @@ value::value(const char* i) : value(std::string(i)) {} ...@@ -161,6 +162,12 @@ value::value(const char* i) : value(std::string(i)) {}
const cpp_type* value::if_##vt() const { return x ? x->if_##vt() : nullptr; } const cpp_type* value::if_##vt() const { return x ? x->if_##vt() : nullptr; }
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS) MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS)
value& value::operator=(const char* c)
{
*this = std::string{c};
return *this;
}
value& value::operator=(std::nullptr_t) value& value::operator=(std::nullptr_t)
{ {
x = nullptr; x = nullptr;
......
...@@ -57,6 +57,15 @@ TEST_CASE(value_construct_string) ...@@ -57,6 +57,15 @@ TEST_CASE(value_construct_string)
EXPECT(v.get_key().empty()); EXPECT(v.get_key().empty());
} }
TEST_CASE(value_construct_key_string_literal_pair)
{
// Use parens instead {} to construct to test the key-pair constructor
migraphx::value v("key", "one");
EXPECT(v.is_string());
EXPECT(v.get_string() == "one");
EXPECT(v.get_key() == "key");
}
TEST_CASE(value_construct_float) TEST_CASE(value_construct_float)
{ {
migraphx::value v = 1.0; migraphx::value v = 1.0;
...@@ -167,6 +176,15 @@ TEST_CASE(value_copy_assign_keyless) ...@@ -167,6 +176,15 @@ TEST_CASE(value_copy_assign_keyless)
EXPECT(v1.without_key() == v2.without_key()); EXPECT(v1.without_key() == v2.without_key());
} }
TEST_CASE(value_assign_key_string_literal_pair)
{
migraphx::value v = migraphx::value::object{};
v["key"] = "one";
EXPECT(v["key"].is_string());
EXPECT(v["key"].get_string() == "one");
EXPECT(v["key"].get_key() == "key");
}
TEST_CASE(value_construct_array) TEST_CASE(value_construct_array)
{ {
migraphx::value v = {1, 2, 3}; migraphx::value v = {1, 2, 3};
...@@ -835,4 +853,38 @@ TEST_CASE(value_or_null) ...@@ -835,4 +853,38 @@ TEST_CASE(value_or_null)
EXPECT(v.value_or(3) == 3); EXPECT(v.value_or(3) == 3);
} }
TEST_CASE(value_get_default)
{
migraphx::value v = {{"key", 1}};
EXPECT(v.get("key", 3) == 1);
EXPECT(v.get("missing", 3) == 3);
}
TEST_CASE(value_get_default_vector)
{
std::vector<int> ints = {1, 2, 3};
std::vector<int> fallback = {-1};
migraphx::value v = {{"key", ints}};
EXPECT(v.get("key", fallback) == ints);
EXPECT(v.get("missing", fallback) == fallback);
EXPECT(v.get("missing", {-1}) == fallback);
}
TEST_CASE(value_get_default_string_literal)
{
migraphx::value v = {{"key", "hello"}};
EXPECT(v.get("key", "none") == "hello");
EXPECT(v.get("missing", "none") == "none");
}
TEST_CASE(value_get_default_string_literal_vector)
{
std::vector<std::string> strings = {"1", "2", "3"};
std::vector<std::string> fallback = {"none"};
migraphx::value v = {{"key", strings}};
EXPECT(v.get("key", fallback) == strings);
EXPECT(v.get("missing", fallback) == fallback);
EXPECT(v.get("missing", {"none"}) == fallback);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment