Unverified Commit f7befe50 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Cpu fusions using post_ops (#781)



* Add eliminate_data_type pass

* Formatting

* Auto convert quant ops

* Formatting

* Flip the order of decompose

* Compute max size differently

* Formatting

* Clamp values in convert

* Formatting

* Fix loss of precision in reduce

* Formatting

* Fix bugs in reduction

* Fix accumulator type in reference softmax implementation

* Formatting

* Update convert test

* Remove unused variables

* Remove unnecessary quant_dot check

* Formatting

* Add tests

* Formatting

* Remove unused code

* Remove duplicate ops

* Remove blaze dependency

* Use set since shape::type_t is no hashable on gcc 5

* Formatting

* Add dnnl binary op

* Formatting

* Add binary and eltwise

* Formatting

* Add softmax

* Formatting

* Remove unused operators

* Add missing files

* Formatting

* Add lrn

* Formatting

* Add deconvolution

* Formatting

* Change allocate default

* Add reorder

* Formatting

* Add reductions

* Formatting

* Sort lines

* Change literals in another loop

* Add pow operator

* Formatting

* Add pow operator

* Formatting

* Make sure shapes are packed

* Allow broadcasted inputs

* Remove unused operators

* Simplify functions

* Remove softmax

* Add sub and erf functions

* Formatting

* Fix bug

* Formatting

* Improve parallism

* Formatting

* Allow multiple batch dimensions

* Formatting

* Move literal transforms out of lowering

* Formatting

* Add gather operator

* Sort lines

* Add early exit for carry

* Formatting

* Add missing concat

* Rename macro

* Fix deep nesting

* Formatting

* Fix cppcheck issues

* Remov else

* Move attribute to typedef

* Formatting

* Disable maybe-uninitialized warning since its broken on gcc

* Add constexpr default constructor

* Formatting

* Fix compiler warnings

* Fix adjust_allocation test

* Add layernorm matcher

* Add gelu_erf matcher

* Formatting

* Add gelu_tanh matcher

* Formatting

* Remove match namespace

* Formatting

* Use matcher instead of string

* Formatting

* Add fusions

* Formatting

* Add post op field

* Formatting

* Make post_ops serializable

* Formatting

* Add eltwise fusions

* Formatting

* Fix null conversions

* Formatting

* Add fuse_ops source files

* Formatting

* Set binary post op index correctly

* Formatting

* Fix serialization bugs

* Check if used once

* Formatting

* Fix error in get_primitive_attr

* Formatting

* Add compile function

* Formatting

* Limit fusions

* Formatting

* Disable with env variable instead of using compile arg

* Formatting

* Fix implicit conversion to bool

* Declar on seperate lines

* Formatting

* Fix cppcheck issues

* Fix ICE in pack_join

* Formatting

* Use const ref

* Make enum hashable

* Formatting

* Add explicit this

* Fix merge issues

* Fix dangling ref

* Formatting

* Add test for compile

* Formatting

* Add more value tests

* Formatting
Co-authored-by: default avatarShucai Xiao <shucai@gmail.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 78eaf2b8
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ASSERT_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ASSERT_HPP
#include <migraphx/config.hpp>
#include <cstdlib>
#include <iostream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class F>
auto abort_on_throw(F f) -> decltype(f())
{
try
{
return f();
}
catch(const std::exception& e)
{
std::cerr << e.what() << std::endl;
std::abort();
}
catch(...)
{
std::cerr << "Unknown exception" << std::endl;
std::abort();
}
}
#ifdef NDEBUG
#define MIGRAPHX_ASSERT_NO_THROW(...) __VA_ARGS__
#else
#define MIGRAPHX_ASSERT_NO_THROW(...) \
migraphx::abort_on_throw([&]() -> decltype(__VA_ARGS__) { return __VA_ARGS__; })
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ASSERT_HPP
...@@ -125,24 +125,6 @@ auto fix(F f) ...@@ -125,24 +125,6 @@ auto fix(F f)
return fix<void>(f); return fix<void>(f);
} }
template <class... Ts>
auto pack(Ts... xs)
{
return [=](auto f) { return f(xs...); };
}
inline auto pack_join() { return pack(); }
template <class P, class... Ps>
auto pack_join(P p, Ps... ps)
{
return [=](auto f) {
return p([&](auto... xs) {
return pack_join(ps...)([&](auto... ys) { return f(xs..., ys...); });
});
};
}
template <class F, class T> template <class F, class T>
auto fold_impl(F&&, T&& x) auto fold_impl(F&&, T&& x)
{ {
...@@ -161,6 +143,22 @@ auto fold(F f) ...@@ -161,6 +143,22 @@ auto fold(F f)
return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); }; return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); };
} }
template <class... Ts>
auto pack(Ts... xs)
{
return [=](auto f) { return f(xs...); };
}
inline auto pack_join() { return pack(); }
template <class... Ps>
auto pack_join(Ps... ps)
{
return fold([](auto p1, auto p2) {
return p1([=](auto... xs) { return p2([=](auto... ys) { return pack(xs..., ys...); }); });
})(ps...);
}
template <class F, class Proj> template <class F, class Proj>
auto by(F f, Proj proj) auto by(F f, Proj proj)
{ {
......
...@@ -19,6 +19,7 @@ shape compute_shape(const operation& op, ...@@ -19,6 +19,7 @@ shape compute_shape(const operation& op,
const std::vector<instruction_ref>& args, const std::vector<instruction_ref>& args,
const std::vector<module_ref>& mods); const std::vector<module_ref>& mods);
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args); std::vector<shape> to_shapes(const std::vector<instruction_ref>& args);
std::vector<shape> try_compute_shape(const operation& op, const std::vector<shape>& inputs);
struct instruction struct instruction
{ {
......
...@@ -183,6 +183,17 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p) ...@@ -183,6 +183,17 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
return {{p}}; return {{p}};
} }
/// Create a typed-erased matcher
using any_matcher_base = basic_matcher<
function_matcher<std::function<instruction_ref(matcher_context&, instruction_ref)>>>;
struct any_matcher : any_matcher_base
{
template <class M>
any_matcher(M mm) : any_matcher_base({[=](auto& ctx, auto ins) { return mm.match(ctx, ins); }})
{
}
};
/// This macro takes care of the boilerplate for defining a matcher /// This macro takes care of the boilerplate for defining a matcher
#define MIGRAPHX_BASIC_MATCHER(name, ...) \ #define MIGRAPHX_BASIC_MATCHER(name, ...) \
struct name##_m \ struct name##_m \
......
...@@ -124,6 +124,7 @@ struct module ...@@ -124,6 +124,7 @@ struct module
std::vector<shape> get_output_shapes() const; std::vector<shape> get_output_shapes() const;
instruction_ref validate() const; instruction_ref validate() const;
instruction_ref find_dangling_reference() const;
void finalize(context& ctx); void finalize(context& ctx);
......
...@@ -35,6 +35,7 @@ struct as_shape ...@@ -35,6 +35,7 @@ struct as_shape
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -66,6 +66,7 @@ struct broadcast ...@@ -66,6 +66,7 @@ struct broadcast
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -50,6 +50,7 @@ struct flatten ...@@ -50,6 +50,7 @@ struct flatten
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -23,6 +23,7 @@ struct identity ...@@ -23,6 +23,7 @@ struct identity
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -36,6 +36,7 @@ struct load ...@@ -36,6 +36,7 @@ struct load
MIGRAPHX_THROW("Load access is out of bounds"); MIGRAPHX_THROW("Load access is out of bounds");
return {s, args[0].data() + offset}; return {s, args[0].data() + offset};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
friend std::ostream& operator<<(std::ostream& os, const load& op) friend std::ostream& operator<<(std::ostream& os, const load& op)
......
...@@ -68,6 +68,7 @@ struct multibroadcast ...@@ -68,6 +68,7 @@ struct multibroadcast
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -71,6 +71,7 @@ struct reshape ...@@ -71,6 +71,7 @@ struct reshape
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -39,6 +39,7 @@ struct scalar ...@@ -39,6 +39,7 @@ struct scalar
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -77,6 +77,7 @@ struct squeeze ...@@ -77,6 +77,7 @@ struct squeeze
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -63,6 +63,7 @@ struct transpose ...@@ -63,6 +63,7 @@ struct transpose
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -70,6 +70,7 @@ struct unsqueeze ...@@ -70,6 +70,7 @@ struct unsqueeze
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -341,6 +341,29 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{}, ...@@ -341,6 +341,29 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return {}; return {};
} }
template <class T>
auto compile_op(
rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(x.compile(auto_any_cast(ctx), output_shape, input))
{
return x.compile(auto_any_cast(ctx), output_shape, input);
}
template <class T>
value compile_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
{
return value::object{};
}
template <class T>
value compile_op(const T& x,
context& ctx,
const shape& output_shape,
const std::vector<shape>& input)
{
return compile_op(rank<1>{}, x, ctx, output_shape, input);
}
template <class T> template <class T>
value attributes_op(const T&) value attributes_op(const T&)
{ {
...@@ -361,6 +384,12 @@ void from_value_op(T& x, const value& v) ...@@ -361,6 +384,12 @@ void from_value_op(T& x, const value& v)
return migraphx::from_value(v, x); return migraphx::from_value(v, x);
} }
template <class T>
bool is_borrowed_op(const T&)
{
return false;
}
} // namespace detail } // namespace detail
/* /*
...@@ -372,7 +401,9 @@ void from_value_op(T& x, const value& v) ...@@ -372,7 +401,9 @@ void from_value_op(T& x, const value& v)
* bool is_context_free() const; * bool is_context_free() const;
* bool need_normalization() const; * bool need_normalization() const;
* bool has_finalize() const; * bool has_finalize() const;
* bool is_borrowed() const;
* std::ptrdiff_t output_alias(const std::vector<shape>& input) const; * std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* value compile(context& ctx,const shape& output,const std::vector<shape>& input) ;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ; * void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
* shape compute_shape(const std::vector<shape>& input) const; * shape compute_shape(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>& * shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>&
...@@ -475,12 +506,24 @@ struct operation ...@@ -475,12 +506,24 @@ struct operation
return (*this).private_detail_te_get_handle().has_finalize(); return (*this).private_detail_te_get_handle().has_finalize();
} }
bool is_borrowed() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_borrowed();
}
std::ptrdiff_t output_alias(const std::vector<shape>& input) const std::ptrdiff_t output_alias(const std::vector<shape>& input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().output_alias(input); return (*this).private_detail_te_get_handle().output_alias(input);
} }
value compile(context& ctx, const shape& output, const std::vector<shape>& input)
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compile(ctx, output, input);
}
void finalize(context& ctx, const shape& output, const std::vector<shape>& input) void finalize(context& ctx, const shape& output, const std::vector<shape>& input)
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -569,7 +612,10 @@ struct operation ...@@ -569,7 +612,10 @@ struct operation
virtual bool is_context_free() const = 0; virtual bool is_context_free() const = 0;
virtual bool need_normalization() const = 0; virtual bool need_normalization() const = 0;
virtual bool has_finalize() const = 0; virtual bool has_finalize() const = 0;
virtual bool is_borrowed() const = 0;
virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0; virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0;
virtual value
compile(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
virtual void virtual void
finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0; finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
virtual shape compute_shape(const std::vector<shape>& input) const = 0; virtual shape compute_shape(const std::vector<shape>& input) const = 0;
...@@ -630,6 +676,19 @@ struct operation ...@@ -630,6 +676,19 @@ struct operation
return detail::has_finalize_op(private_detail_te_self); return detail::has_finalize_op(private_detail_te_self);
} }
template <class T>
static auto private_detail_te_default_is_borrowed(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.is_borrowed())
{
return private_detail_te_self.is_borrowed();
}
template <class T>
static bool private_detail_te_default_is_borrowed(float, T&& private_detail_te_self)
{
return detail::is_borrowed_op(private_detail_te_self);
}
template <class T> template <class T>
static auto private_detail_te_default_output_alias(char, static auto private_detail_te_default_output_alias(char,
T&& private_detail_te_self, T&& private_detail_te_self,
...@@ -647,6 +706,27 @@ struct operation ...@@ -647,6 +706,27 @@ struct operation
return detail::output_alias_op(private_detail_te_self, input); return detail::output_alias_op(private_detail_te_self, input);
} }
template <class T>
static auto private_detail_te_default_compile(char,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<shape>& input)
-> decltype(private_detail_te_self.compile(ctx, output, input))
{
return private_detail_te_self.compile(ctx, output, input);
}
template <class T>
static value private_detail_te_default_compile(float,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<shape>& input)
{
return detail::compile_op(private_detail_te_self, ctx, output, input);
}
template <class T> template <class T>
static auto private_detail_te_default_finalize(char, static auto private_detail_te_default_finalize(char,
T&& private_detail_te_self, T&& private_detail_te_self,
...@@ -858,12 +938,25 @@ struct operation ...@@ -858,12 +938,25 @@ struct operation
return private_detail_te_default_has_finalize(char(0), private_detail_te_value); return private_detail_te_default_has_finalize(char(0), private_detail_te_value);
} }
bool is_borrowed() const override
{
return private_detail_te_default_is_borrowed(char(0), private_detail_te_value);
}
std::ptrdiff_t output_alias(const std::vector<shape>& input) const override std::ptrdiff_t output_alias(const std::vector<shape>& input) const override
{ {
return private_detail_te_default_output_alias(char(0), private_detail_te_value, input); return private_detail_te_default_output_alias(char(0), private_detail_te_value, input);
} }
value compile(context& ctx, const shape& output, const std::vector<shape>& input) override
{
return private_detail_te_default_compile(
char(0), private_detail_te_value, ctx, output, input);
}
void finalize(context& ctx, const shape& output, const std::vector<shape>& input) override void finalize(context& ctx, const shape& output, const std::vector<shape>& input) override
{ {
...@@ -1010,6 +1103,24 @@ inline const ValueType& any_cast(const operation& x) ...@@ -1010,6 +1103,24 @@ inline const ValueType& any_cast(const operation& x)
inline bool operator!=(const operation& x, const operation& y) { return !(x == y); } inline bool operator!=(const operation& x, const operation& y) { return !(x == y); }
inline value
compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
return op.compile(ctx, output_shape, input);
}
template <class Context>
inline value
compile(operation& op, Context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
dependent_type<context, Context> ctx2 = std::ref(ctx);
return compile(op, ctx2, output_shape, input);
}
template <class T, class Context>
inline auto compile(T& op, Context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(op.compile(ctx, ctx, output_shape, input))
{
return op.compile(ctx, ctx, output_shape, input);
}
inline shape compute_shape(const operation& op, const std::vector<shape>& inputs) inline shape compute_shape(const operation& op, const std::vector<shape>& inputs)
{ {
return op.compute_shape(inputs); return op.compute_shape(inputs);
......
...@@ -102,6 +102,29 @@ void reflect_each(T& x, F f) ...@@ -102,6 +102,29 @@ void reflect_each(T& x, F f)
}); });
} }
template <class T>
struct reflect_equality
{
friend bool operator==(const T& x, const T& y) { return reflect_tie(x) == reflect_tie(y); }
friend bool operator!=(const T& x, const T& y) { return !(x == y); }
};
template <class T>
struct reflect_stream
{
template <class Stream>
friend Stream& operator<<(Stream& os, const T& x)
{
char d = '{';
reflect_each(x, [&](const auto& y, const auto& name) {
os << d << name << "=" << y;
d = ',';
});
os << "}";
return os;
}
};
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
File mode changed from 100644 to 100755
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <migraphx/type_name.hpp> #include <migraphx/type_name.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <algorithm> #include <algorithm>
#include <cassert>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include <type_traits> #include <type_traits>
...@@ -58,8 +59,6 @@ struct value_converter<std::string> ...@@ -58,8 +59,6 @@ struct value_converter<std::string>
{ {
static const std::string& apply(const std::string& x) { return x; } static const std::string& apply(const std::string& x) { return x; }
static std::string apply(const std::nullptr_t&) { return "null"; }
template <class From> template <class From>
static auto apply(const From& x) static auto apply(const From& x)
-> decltype(std::declval<std::stringstream&>() << x, std::string()) -> decltype(std::declval<std::stringstream&>() << x, std::string())
...@@ -83,20 +82,28 @@ struct value_converter<std::pair<T, U>> ...@@ -83,20 +82,28 @@ struct value_converter<std::pair<T, U>>
} }
}; };
template <class To, class From>
To try_convert_value(const From& x);
namespace detail { namespace detail {
template <class To, class Key, class From> template <class To, class Key, class From>
auto try_convert_value_impl(rank<2>, const std::pair<Key, From>& x) To try_convert_value_impl(rank<1>, const std::pair<Key, From>& x)
-> decltype(value_converter<To>::apply(x.second))
{ {
return value_converter<To>::apply(x.second); return try_convert_value<To>(x.second);
} }
template <class To, class From> template <class To, class From>
auto try_convert_value_impl(rank<1>, const From& x) -> decltype(value_converter<To>::apply(x)) auto try_convert_value_impl(rank<2>, const From& x) -> decltype(value_converter<To>::apply(x))
{ {
return value_converter<To>::apply(x); return value_converter<To>::apply(x);
} }
template <class To, MIGRAPHX_REQUIRES(not std::is_same<To, std::nullptr_t>{})>
To try_convert_value_impl(rank<3>, std::nullptr_t)
{
MIGRAPHX_THROW("Incompatible values: null -> " + get_type_name<To>());
}
template <class To, class From> template <class To, class From>
To try_convert_value_impl(rank<0>, const From& x) To try_convert_value_impl(rank<0>, const From& x)
{ {
...@@ -107,7 +114,7 @@ To try_convert_value_impl(rank<0>, const From& x) ...@@ -107,7 +114,7 @@ To try_convert_value_impl(rank<0>, const From& x)
template <class To, class From> template <class To, class From>
To try_convert_value(const From& x) To try_convert_value(const From& x)
{ {
return detail::try_convert_value_impl<To>(rank<2>{}, x); return detail::try_convert_value_impl<To>(rank<3>{}, x);
} }
struct value struct value
...@@ -309,7 +316,11 @@ struct value ...@@ -309,7 +316,11 @@ struct value
{ {
case null_type: case null_type:
{ {
v(std::nullptr_t{}); std::nullptr_t null{};
if(this->key.empty())
v(null);
else
v(std::make_pair(this->get_key(), std::ref(null)));
return; return;
} }
#define MIGRAPHX_VALUE_GENERATE_CASE(vt, cpp_type) \ #define MIGRAPHX_VALUE_GENERATE_CASE(vt, cpp_type) \
...@@ -328,6 +339,31 @@ struct value ...@@ -328,6 +339,31 @@ struct value
MIGRAPHX_THROW("Unknown type"); MIGRAPHX_THROW("Unknown type");
} }
// Visit value without key
template <class Visitor>
void visit_value(Visitor v) const
{
switch(this->get_type())
{
case null_type:
{
std::nullptr_t null{};
v(null);
return;
}
#define MIGRAPHX_VALUE_GENERATE_CASE_VALUE(vt, cpp_type) \
case vt##_type: \
{ \
v(this->get_##vt()); \
return; \
}
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE)
MIGRAPHX_VALUE_GENERATE_CASE(array, )
MIGRAPHX_VALUE_GENERATE_CASE(object, )
}
MIGRAPHX_THROW("Unknown type");
}
template <class To> template <class To>
To to() const To to() const
{ {
...@@ -336,6 +372,14 @@ struct value ...@@ -336,6 +372,14 @@ struct value
return result; return result;
} }
template <class To>
To value_or(const To& default_value) const
{
if(this->is_null())
return default_value;
return to<To>();
}
template <class To> template <class To>
std::vector<To> to_vector() const std::vector<To> to_vector() const
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment