Unverified Commit c73c0dae authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Improve handling of string literals in value class (#1141)

* Handle string literal in construction
* Improve get_default with vector
parent 251cdd74
......@@ -178,6 +178,7 @@ struct value
value(std::nullptr_t);
value(const char* i);
value(const std::string& pkey, const char* i);
#define MIGRAPHX_VALUE_GENERATE_DECL_METHODS(vt, cpp_type) \
value(cpp_type i); \
......@@ -188,6 +189,12 @@ struct value
const cpp_type* if_##vt() const;
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DECL_METHODS)
template <class T>
using literal_to_string = std::conditional_t<(std::is_convertible<T, const char*>{} and
std::is_convertible<T, std::string>{}),
std::string,
T>;
template <class T>
using pick_numeric = std::conditional_t<
std::is_floating_point<T>{},
......@@ -246,6 +253,7 @@ struct value
return *this = from_values(rhs); // NOLINT
}
value& operator=(const char* c);
value& operator=(std::nullptr_t);
value& operator=(const std::initializer_list<value>& i);
......@@ -370,11 +378,11 @@ struct value
}
template <class To>
To value_or(const To& default_value) const
literal_to_string<To> value_or(const To& default_value) const
{
if(this->is_null())
return default_value;
return to<To>();
return to<literal_to_string<To>>();
}
template <class To>
......@@ -390,12 +398,12 @@ struct value
}
template <class To>
To get(const std::string& pkey, const To& default_value) const
literal_to_string<To> get(const std::string& pkey, const To& default_value) const
{
const auto* v = find(pkey);
if(v == this->end())
return default_value;
return v->to<To>();
return v->to<literal_to_string<To>>();
}
template <class To>
......@@ -408,10 +416,11 @@ struct value
}
template <class To>
std::vector<To> get(const std::string& pkey,
std::vector<literal_to_string<To>> get(const std::string& pkey,
const std::initializer_list<To>& default_value) const
{
return get<std::vector<To>>(pkey, default_value);
return get(pkey,
std::vector<literal_to_string<To>>{default_value.begin(), default_value.end()});
}
friend bool operator==(const value& x, const value& y);
......
......@@ -213,7 +213,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::shape>(m, "shape")
.def(py::init([](py::kwargs kwargs) {
auto v = migraphx::to_value(kwargs);
auto t = migraphx::shape::parse_type(v.get("type", std::string{"float"}));
auto t = migraphx::shape::parse_type(v.get("type", "float"));
auto lens = v.get<std::size_t>("lens", {1});
if(v.contains("strides"))
return migraphx::shape(t, lens, v.at("strides").to_vector<std::size_t>());
......
......@@ -138,6 +138,7 @@ value::value(const std::string& pkey, const value& rhs)
{
}
value::value(const std::string& pkey, const char* i) : value(pkey, std::string(i)) {}
value::value(const char* i) : value(std::string(i)) {}
#define MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS(vt, cpp_type) \
......@@ -161,6 +162,12 @@ value::value(const char* i) : value(std::string(i)) {}
const cpp_type* value::if_##vt() const { return x ? x->if_##vt() : nullptr; }
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS)
value& value::operator=(const char* c)
{
*this = std::string{c};
return *this;
}
value& value::operator=(std::nullptr_t)
{
x = nullptr;
......
......@@ -57,6 +57,15 @@ TEST_CASE(value_construct_string)
EXPECT(v.get_key().empty());
}
TEST_CASE(value_construct_key_string_literal_pair)
{
// Use parens instead {} to construct to test the key-pair constructor
migraphx::value v("key", "one");
EXPECT(v.is_string());
EXPECT(v.get_string() == "one");
EXPECT(v.get_key() == "key");
}
TEST_CASE(value_construct_float)
{
migraphx::value v = 1.0;
......@@ -167,6 +176,15 @@ TEST_CASE(value_copy_assign_keyless)
EXPECT(v1.without_key() == v2.without_key());
}
TEST_CASE(value_assign_key_string_literal_pair)
{
migraphx::value v = migraphx::value::object{};
v["key"] = "one";
EXPECT(v["key"].is_string());
EXPECT(v["key"].get_string() == "one");
EXPECT(v["key"].get_key() == "key");
}
TEST_CASE(value_construct_array)
{
migraphx::value v = {1, 2, 3};
......@@ -835,4 +853,38 @@ TEST_CASE(value_or_null)
EXPECT(v.value_or(3) == 3);
}
TEST_CASE(value_get_default)
{
migraphx::value v = {{"key", 1}};
EXPECT(v.get("key", 3) == 1);
EXPECT(v.get("missing", 3) == 3);
}
TEST_CASE(value_get_default_vector)
{
std::vector<int> ints = {1, 2, 3};
std::vector<int> fallback = {-1};
migraphx::value v = {{"key", ints}};
EXPECT(v.get("key", fallback) == ints);
EXPECT(v.get("missing", fallback) == fallback);
EXPECT(v.get("missing", {-1}) == fallback);
}
TEST_CASE(value_get_default_string_literal)
{
migraphx::value v = {{"key", "hello"}};
EXPECT(v.get("key", "none") == "hello");
EXPECT(v.get("missing", "none") == "none");
}
TEST_CASE(value_get_default_string_literal_vector)
{
std::vector<std::string> strings = {"1", "2", "3"};
std::vector<std::string> fallback = {"none"};
migraphx::value v = {{"key", strings}};
EXPECT(v.get("key", fallback) == strings);
EXPECT(v.get("missing", fallback) == fallback);
EXPECT(v.get("missing", {"none"}) == fallback);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment