"docs/vscode:/vscode.git/clone" did not exist on "7a065a9c56438dad282a0709df1b45bf2bd3bd6a"
Unverified Commit 11212a94 authored by pfeatherstone's avatar pfeatherstone Committed by GitHub
Browse files

[SERIALIZATION] added support for std::variant (#2362)



* [SERIALIZATION] addes support for std::variant

* [SERIALIZATION] bug fix + added tests

* support immutable types

* put an immutable type in std::variant
Co-authored-by: default avatarpf <pf@me>
parent a54507d8
......@@ -89,6 +89,7 @@
- std::complex
- std::unique_ptr
- std::shared_ptr
- std::variant (if C++17 is used)
- dlib::uint64
- dlib::int64
- float_details
......@@ -119,6 +120,7 @@
- std::complex
- std::unique_ptr
- std::shared_ptr
- std::variant (if C++17 is used)
- dlib::uint64
- dlib::int64
- float_details
......@@ -233,6 +235,9 @@
#include <limits>
#include <type_traits>
#include <utility>
#if __cplusplus >= 201703L
#include <variant>
#endif
#include "uintn.h"
#include "interfaces/enumerable.h"
#include "interfaces/map_pair.h"
......@@ -1216,6 +1221,59 @@ namespace dlib
{ throw serialization_error(e.info + "\n while deserializing object of type std::tuple"); }
}
// ----------------------------------------------------------------------------------------
#if __cplusplus >= 201703L
namespace detail
{
template<typename Variant, std::size_t I = 0>
void deserialize_variant_helper(Variant& item, std::istream& in, std::size_t index)
{
if constexpr (I < std::variant_size_v<Variant>)
{
if (I == index)
{
auto& x = item.template emplace<std::variant_alternative_t<I,Variant>>();
deserialize(bin, x);
}
else
{
deserialize_variant_helper<Variant, I+1>(item, in, index);
}
}
else
{
throw serialization_error("deserialize_variant_helper() index is out of range of variant size");
}
}
}
template<typename... Types>
void serialize(
const std::variant<Types...>& item,
std::ostream& out
)
{
serialize(item.index(), out);
std::visit([&out](const auto& x) {
serialize(x, out);
}, item);
}
template <typename... Types>
void deserialize(
std::variant<Types...>& item,
std::istream& in
)
{
std::size_t index = 0;
deserialize(index, in);
detail::deserialize_variant_helper(item, in, index);
}
#endif
// ----------------------------------------------------------------------------------------
template <typename domain, typename range, typename compare, typename alloc>
......
......@@ -414,6 +414,19 @@ namespace
return l && r ? *l == *r : l == r;
}
struct immutable_type
{
immutable_type() = default;
immutable_type(const immutable_type& other) = delete;
immutable_type& operator=(const immutable_type& other) = delete;
immutable_type(immutable_type&& other) = delete;
immutable_type& operator=(immutable_type&& other) = delete;
friend void serialize(const immutable_type& x, std::ostream& out) {}
friend void deserialize(immutable_type& x, std::istream& in) {}
bool operator==(const immutable_type& other) const {return true;}
};
struct my_custom_type
{
int a;
......@@ -433,16 +446,29 @@ namespace
std::unordered_multiset<string> o;
std::shared_ptr<string> ptr_shared1;
std::shared_ptr<string> ptr_shared2;
std::vector<std::complex<double>> p;
std::vector<std::complex<double>> p;
#if __cplusplus >= 201703L
std::variant<int,float,std::string,immutable_type> q;
#endif
bool operator==(const my_custom_type& rhs) const
{
#if __cplusplus >= 201703L
const bool cpp17_ok = q == rhs.q;
#else
const bool cpp17_ok = true;
#endif
return std::tie(a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p) == std::tie(rhs.a,rhs.b,rhs.c,rhs.d,rhs.e,rhs.f,rhs.g,rhs.h,rhs.i,rhs.j,rhs.k,rhs.l,rhs.m,rhs.n,rhs.o,rhs.p)
&& cpp17_ok
&& pointers_values_equal(ptr_shared1, rhs.ptr_shared1)
&& pointers_values_equal(ptr_shared2, rhs.ptr_shared2);
}
#if __cplusplus >= 201703L
DLIB_DEFINE_DEFAULT_SERIALIZATION(my_custom_type, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, ptr_shared1, ptr_shared2, q);
#else
DLIB_DEFINE_DEFAULT_SERIALIZATION(my_custom_type, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, ptr_shared1, ptr_shared2);
#endif
};
struct my_custom_type_array
......@@ -1124,7 +1150,9 @@ namespace
dlib::rand rng(std::time(NULL));
for (int i = 0 ; i < 1024 ; i++)
t1.p.push_back(rng.get_random_gaussian());
#if __cplusplus >= 201703L
t1.q = "hello there from std::variant, welcome!";
#endif
t2.a = 2;
t2.b = 4.0;
t2.c.resize(10);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment