Unverified Commit aaac87a2 authored by pfeatherstone's avatar pfeatherstone Committed by GitHub
Browse files

[TYPE_SAFE_UNION] simplified some type traits and added for_each(). (#2475)

* [TYPE_SAFE_UNION] simplified some type traits and added for_each().

* added example serialization/deserialization using typeid().hash_code and different type_safe_union types

* in_place_tag is an empty struct. so don't pass const references, you're unecessarily passing 8 bytes around for now reason

* - added variant_size for type_safe_union
- added variant_alternative for type_safe_union
- removed for type_safe_union::for_each() and replaced with global function dlib::for_each_type()

* - made visit() a global

* use dlib::invoke explicitly

* - for_each_type is implemented using fold expression (or whatever the right term is) instead of template recursion. This method, in theory, yields better compile times. And if you're familiar with parameter packs, then the implementation is easier to read.

* - refactoring
- reordered function parameters in for_each_type()
- vtable implementation of apply_to_contents and visit() (sorry Davis for yet another change)

* add option to not zero out gradients and method to do it (#2477)

* Avoid different kinds of compiler warnings (#2481)

* Avoid different kinds of compiler warnings that started to appear when upgrading my build environment

* Avoid more compiler warnings

* Revert the overly verbose static_cast changes

* Make resize_bilinear and resize_bilinear_gradient take long long (previously just long)

* Circumvent what appears to be a bug in Visual Studio 2019's optimizer
(see: https://forum.juce.com/t/warning-in-the-lastest-vs2019/38267

)

* Fix MSVC pragma warnings with other compilers (#2483)

* Fix warning about unused zero_gradients parameter (#2487)

* Fix warning about unused zero_gradients parameter

* match signature of other methods

* cleanup
Co-authored-by: default avatarpfeatherstone <peter@me>
Co-authored-by: default avatarAdrià Arrufat <1671644+arrufat@users.noreply.github.com>
Co-authored-by: default avatarJuha Reunanen <juha.reunanen@tomaattinen.com>
Co-authored-by: default avatarDavis King <davis@dlib.net>
parent 42e08696
...@@ -80,6 +80,13 @@ namespace ...@@ -80,6 +80,13 @@ namespace
public: public:
void test_stuff() void test_stuff()
{ {
static_assert(tsu::get_type_id<float>() == 1, "bad type id");
static_assert(tsu::get_type_id<double>() == 2, "bad type id");
static_assert(tsu::get_type_id<char>() == 3, "bad type id");
static_assert(tsu::get_type_id<std::string>() == 4, "bad type id");
static_assert(tsu::get_type_id<long>() == -1, "This should be -1");
DLIB_TEST(a.is_empty() == true); DLIB_TEST(a.is_empty() == true);
DLIB_TEST(a.contains<char>() == false); DLIB_TEST(a.contains<char>() == false);
DLIB_TEST(a.contains<float>() == false); DLIB_TEST(a.contains<float>() == false);
...@@ -566,7 +573,7 @@ namespace ...@@ -566,7 +573,7 @@ namespace
DLIB_TEST(b.is_empty()); DLIB_TEST(b.is_empty());
DLIB_TEST(b.get_current_type_id() == 0); DLIB_TEST(b.get_current_type_id() == 0);
//visit can return non-void types //visit can return non-void types
auto ret = a.visit(overloaded( auto ret = visit(overloaded(
[](int) { [](int) {
return std::string("int"); return std::string("int");
}, },
...@@ -576,7 +583,7 @@ namespace ...@@ -576,7 +583,7 @@ namespace
[](const std::string&) { [](const std::string&) {
return std::string("std::string"); return std::string("std::string");
} }
)); ), a);
static_assert(std::is_same<std::string, decltype(ret)>::value, "bad return type"); static_assert(std::is_same<std::string, decltype(ret)>::value, "bad return type");
DLIB_TEST(ret == "int"); DLIB_TEST(ret == "int");
//apply_to_contents can only return void //apply_to_contents can only return void
...@@ -604,7 +611,7 @@ namespace ...@@ -604,7 +611,7 @@ namespace
tsu_b object(dlib::in_place_tag<tsu_a>{}, std::string("hello from bottom node")); tsu_b object(dlib::in_place_tag<tsu_a>{}, std::string("hello from bottom node"));
DLIB_TEST(object.contains<tsu_a>()); DLIB_TEST(object.contains<tsu_a>());
DLIB_TEST(object.get<tsu_a>().get<std::string>() == "hello from bottom node"); DLIB_TEST(object.get<tsu_a>().get<std::string>() == "hello from bottom node");
auto ret = object.visit(overloaded( auto ret = visit(overloaded(
[](int) { [](int) {
return std::string("int"); return std::string("int");
}, },
...@@ -615,7 +622,7 @@ namespace ...@@ -615,7 +622,7 @@ namespace
return std::string("std::string"); return std::string("std::string");
}, },
[](const tsu_a& item) { [](const tsu_a& item) {
return item.visit(overloaded( return visit( overloaded(
[](int) { [](int) {
return std::string("nested int"); return std::string("nested int");
}, },
...@@ -625,20 +632,19 @@ namespace ...@@ -625,20 +632,19 @@ namespace
[](std::string str) { [](std::string str) {
return str; return str;
} }
)); ), item);
} }
)); ), object);
static_assert(std::is_same<std::string, decltype(ret)>::value, "bad type"); static_assert(std::is_same<std::string, decltype(ret)>::value, "bad type");
DLIB_TEST(ret == "hello from bottom node"); DLIB_TEST(ret == "hello from bottom node");
} }
{ {
//"private" visitor //struct visitor
using tsu = type_safe_union<int,float,std::string>; using tsu = type_safe_union<int,float,std::string>;
class visitor_private struct visitor_private
{ {
private:
std::string operator()(int) std::string operator()(int)
{ {
return std::string("int"); return std::string("int");
...@@ -653,19 +659,162 @@ namespace ...@@ -653,19 +659,162 @@ namespace
{ {
return str; return str;
} }
friend tsu;
}; };
visitor_private visitor; visitor_private visitor;
tsu a = std::string("hello from private visitor"); tsu a = std::string("hello from private visitor");
auto ret = a.visit(visitor); auto ret = visit(visitor, a);
static_assert(std::is_same<std::string, decltype(ret)>::value, "bad type"); static_assert(std::is_same<std::string, decltype(ret)>::value, "bad type");
DLIB_TEST(ret == "hello from private visitor"); DLIB_TEST(ret == "hello from private visitor");
} }
} }
}; };
namespace test_for_each_1
{
/*! Local classes aren't allowed to have template member functions... !*/
using tsu = type_safe_union<int,float,std::string>;
static_assert(type_safe_union_size<tsu>::value == 3, "bad number of types");
static_assert(std::is_same<type_safe_union_alternative_t<0, tsu>, int>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<1, tsu>, float>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<2, tsu>, std::string>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<0, const tsu>, const int>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<1, const tsu>, const float>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<2, const tsu>, const std::string>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<0, volatile tsu>, volatile int>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<1, volatile tsu>, volatile float>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<2, volatile tsu>, volatile std::string>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<0, const volatile tsu>, const volatile int>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<1, const volatile tsu>, const volatile float>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<2, const volatile tsu>, const volatile std::string>::value, "bad type");
struct for_each_visitor
{
std::vector<int> type_indices;
template<typename T>
void operator()(dlib::in_place_tag<T>, const tsu& item)
{
type_indices.push_back(item.get_type_id<T>());
}
};
void test()
{
tsu a;
for_each_visitor visitor;
for_each_type(visitor, a);
DLIB_TEST(visitor.type_indices.size() == 3);
DLIB_TEST(visitor.type_indices[0] == 1);
DLIB_TEST(visitor.type_indices[1] == 2);
DLIB_TEST(visitor.type_indices[2] == 3);
}
}
namespace test_for_each_2
{
/*! Local classes aren't allowed to have template member functions... !*/
using tsu = type_safe_union<int,float,std::string>;
//for_each() that demonstrates an actual use-case.
//Instead of something simple like a target index, you might want to set the variant
//based on some specific state that is unique to one of the alternative types.
//For example, you might want to conditionally set the variant based on hashes.
struct for_each_visitor
{
for_each_visitor(int target_index_) : target_index(target_index_) {}
template<typename TagType>
void operator()(TagType tag, tsu& item)
{
if (item.get_type_id<TagType>() == target_index)
item = tsu{tag};
}
const int target_index = 0;
};
void test()
{
tsu a;
for_each_type(for_each_visitor{1}, a);
DLIB_TEST(a.contains<int>());
a.clear();
for_each_type(for_each_visitor{2}, a);
DLIB_TEST(a.contains<float>());
a.clear();
for_each_type(for_each_visitor{3}, a);
DLIB_TEST(a.contains<std::string>());
a.clear();
for_each_type(for_each_visitor{-1}, a);
DLIB_TEST(a.is_empty());
a.clear();
for_each_type(for_each_visitor{215465}, a);
DLIB_TEST(a.is_empty());
}
}
namespace test_for_each_3
{
using tsu1 = type_safe_union<int,float,std::string>;
using tsu2 = type_safe_union<std::string,long,char>;
struct serializer_typeid
{
serializer_typeid(std::ostream& out_) : out(out_) {}
template<typename T>
void operator()(const T& x)
{
dlib::serialize(typeid(T).hash_code(), out);
dlib::serialize(x, out);
}
std::ostream& out;
};
struct deserializer_typeid
{
deserializer_typeid(std::istream& in_) : in(in_)
{
dlib::deserialize(hash_code, in);
}
template<typename T, typename TSU>
void operator()(in_place_tag<T>, TSU&& me)
{
if (typeid(T).hash_code() == hash_code)
dlib::deserialize(me.template get<T>(), in);
}
std::size_t hash_code = 0;
std::istream& in;
};
void test()
{
tsu1 a;
a.get<std::string>() = "hello from tsu1";
std::stringstream out;
visit(serializer_typeid(out), a);
tsu2 b;
for_each_type(deserializer_typeid(out), b);
DLIB_TEST(b.contains<std::string>());
DLIB_TEST(b.get<std::string>() == "hello from tsu1");
}
}
class type_safe_union_tester : public tester class type_safe_union_tester : public tester
{ {
public: public:
...@@ -682,6 +831,9 @@ namespace ...@@ -682,6 +831,9 @@ namespace
{ {
test a; test a;
a.test_stuff(); a.test_stuff();
test_for_each_1::test();
test_for_each_2::test();
test_for_each_3::test();
} }
} }
} a; } a;
......
...@@ -9,9 +9,12 @@ ...@@ -9,9 +9,12 @@
#include <type_traits> #include <type_traits>
#include <functional> #include <functional>
#include "../serialize.h" #include "../serialize.h"
#include "../invoke.h"
namespace dlib namespace dlib
{ {
// ---------------------------------------------------------------------
class bad_type_safe_union_cast : public std::bad_cast class bad_type_safe_union_cast : public std::bad_cast
{ {
public: public:
...@@ -21,87 +24,88 @@ namespace dlib ...@@ -21,87 +24,88 @@ namespace dlib
} }
}; };
// ---------------------------------------------------------------------
template<typename T> template<typename T>
struct in_place_tag {}; struct in_place_tag { using type = T;};
namespace internal
{
// --------------------------------------------------------------------- // ---------------------------------------------------------------------
template <typename T, typename... Rest>
struct is_any : std::false_type {};
template <typename T, typename First> template <typename... Types> class type_safe_union;
struct is_any<T,First> : std::is_same<T,First> {};
template <typename T, typename First, typename... Rest> template<typename Tsu>
struct is_any<T,First,Rest...> : std::integral_constant<bool, std::is_same<T,First>::value || is_any<T,Rest...>::value> {}; struct type_safe_union_size;
template<typename... Types>
struct type_safe_union_size<type_safe_union<Types...>> : std::integral_constant<size_t, sizeof...(Types)> {};
template<typename Tsu> struct type_safe_union_size<const Tsu> : type_safe_union_size<Tsu> {};
template<typename Tsu> struct type_safe_union_size<volatile Tsu> : type_safe_union_size<Tsu> {};
template<typename Tsu> struct type_safe_union_size<const volatile Tsu> : type_safe_union_size<Tsu> {};
// --------------------------------------------------------------------- // ---------------------------------------------------------------------
namespace detail namespace detail
{ {
struct empty {}; template<size_t I, typename... Ts>
struct nth_type;
template<typename T> template<size_t I, typename T0, typename... Ts>
struct variant_type_placeholder {using type = T;}; struct nth_type<I, T0, Ts...> : nth_type<I-1, Ts...> {};
template < template<typename T0, typename... Ts>
size_t nVariantTypes, struct nth_type<0, T0, Ts...> { using type = T0; };
size_t Counter,
size_t I,
typename... VariantTypes
>
struct variant_get_type_impl {};
template <
size_t nVariantTypes,
size_t Counter,
size_t I,
typename VariantTypeFirst,
typename... VariantTypeRest
>
struct variant_get_type_impl<nVariantTypes, Counter, I, VariantTypeFirst, VariantTypeRest...>
: std::conditional<(Counter < nVariantTypes),
typename std::conditional<(I == Counter),
variant_type_placeholder<VariantTypeFirst>,
variant_get_type_impl<nVariantTypes, Counter+1, I, VariantTypeRest...>
>::type,
empty
>::type {};
} }
template <size_t I, typename... VariantTypes> template <size_t I, typename TSU>
struct variant_get_type : detail::variant_get_type_impl<sizeof...(VariantTypes), 0, I, VariantTypes...> {}; struct type_safe_union_alternative;
template <size_t I, typename... Types>
struct type_safe_union_alternative<I, type_safe_union<Types...>> : detail::nth_type<I, Types...>{};
template<size_t I, typename TSU>
using type_safe_union_alternative_t = typename type_safe_union_alternative<I, TSU>::type;
template <size_t I, typename TSU>
struct type_safe_union_alternative<I, const TSU>
{ using type = typename std::add_const<type_safe_union_alternative_t<I, TSU>>::type; };
template <size_t I, typename TSU>
struct type_safe_union_alternative<I, volatile TSU>
{ using type = typename std::add_volatile<type_safe_union_alternative_t<I, TSU>>::type; };
template <size_t I, typename TSU>
struct type_safe_union_alternative<I, const volatile TSU>
{ using type = typename std::add_cv<type_safe_union_alternative_t<I, TSU>>::type; };
// --------------------------------------------------------------------- // ---------------------------------------------------------------------
namespace detail namespace detail
{ {
template < // ---------------------------------------------------------------------
size_t nVariantTypes,
size_t Counter, template <typename T, typename First, typename... Rest>
typename T, struct is_any : std::integral_constant<bool, is_any<T,First>::value || is_any<T,Rest...>::value> {};
typename... VariantTypes
> template <typename T, typename First>
struct variant_type_id_impl : std::integral_constant<int,-1> {}; struct is_any<T,First> : std::is_same<T,First> {};
// ---------------------------------------------------------------------
template <int nTs, typename T, typename... Ts>
struct type_safe_union_type_id_impl
: std::integral_constant<int, -1 - nTs> {};
template <int nTs, typename T, typename T0, typename... Ts>
struct type_safe_union_type_id_impl<nTs, T, T0, Ts...>
: std::integral_constant<int, std::is_same<T,T0>::value ? 1 : type_safe_union_type_id_impl<nTs, T,Ts...>::value + 1> {};
template <typename T, typename... Ts>
struct type_safe_union_type_id : type_safe_union_type_id_impl<sizeof...(Ts),T,Ts...>{};
template <typename T, typename... Ts>
struct type_safe_union_type_id<in_place_tag<T>, Ts...> : type_safe_union_type_id<T,Ts...>{};
template <
size_t nVariantTypes,
size_t Counter,
typename T,
typename VariantTypeFirst,
typename... VariantTypesRest
>
struct variant_type_id_impl<nVariantTypes,Counter,T,VariantTypeFirst,VariantTypesRest...>
: std::conditional<(Counter < nVariantTypes),
typename std::conditional<std::is_same<T,VariantTypeFirst>::value,
std::integral_constant<int,Counter+1>,
variant_type_id_impl<nVariantTypes,Counter + 1, T, VariantTypesRest...>
>::type,
std::integral_constant<int,-1>
>::type {};
}
template <typename T, typename... VariantTypes>
struct variant_type_id : detail::variant_type_id_impl<sizeof...(VariantTypes), 0, T, VariantTypes...> {};
// --------------------------------------------------------------------- // ---------------------------------------------------------------------
} }
...@@ -119,124 +123,57 @@ namespace dlib ...@@ -119,124 +123,57 @@ namespace dlib
template <typename T> template <typename T>
static constexpr int get_type_id () static constexpr int get_type_id ()
{ {
return internal::variant_type_id<T,Types...>::value; return detail::type_safe_union_type_id<T,Types...>::value;
} }
private: private:
template<typename T> template<typename T>
struct is_valid : internal::is_any<T,Types...> {}; struct is_valid : detail::is_any<T,Types...> {};
template<typename T> template<typename T>
using is_valid_check = typename std::enable_if<is_valid<T>::value, bool>::type; using is_valid_check = typename std::enable_if<is_valid<T>::value, bool>::type;
template <size_t I> template <size_t I>
struct get_type : internal::variant_get_type<I,Types...> {}; using get_type_t = type_safe_union_alternative_t<I, type_safe_union>;
template <size_t I>
using get_type_t = typename get_type<I>::type;
using T0 = get_type_t<0>;
template<typename F, typename T>
struct return_type
{
using type = decltype(std::declval<F>()(std::declval<T>()));
};
template<typename F, typename T>
using return_type_t = typename return_type<F,T>::type;
typename std::aligned_union<0, Types...>::type mem; typename std::aligned_union<0, Types...>::type mem;
int type_identity = 0; int type_identity = 0;
template< template<
size_t I, typename F,
typename F typename TSU,
> std::size_t I
auto visit_impl(
F&&
) -> typename std::enable_if<
(I == sizeof...(Types)) &&
std::is_same<void, return_type_t<F, T0&>>::value
>::type
{
}
template<
size_t I,
typename F
> >
auto visit_impl( static void apply_to_contents_as_type(
F&& F&& f,
) -> typename std::enable_if< TSU&& me
(I == sizeof...(Types)) && )
! std::is_same<void, return_type_t<F, T0&>>::value,
return_type_t<F, T0&>
>::type
{ {
return return_type_t<F, T0&>{}; std::forward<F>(f)(me.template unchecked_get<get_type_t<I>>());
} }
template< template<
size_t I, typename F,
typename F typename TSU,
std::size_t... I
> >
auto visit_impl( static void apply_to_contents_impl(
F&& f F&& f,
) -> typename std::enable_if< TSU&& me,
(I < sizeof...(Types)), dlib::index_sequence<I...>
return_type_t<F, T0&> )
>::type
{ {
if (type_identity == (I+1)) using func_t = void(*)(F&&, TSU&&);
return std::forward<F>(f)(unchecked_get<get_type_t<I>>());
else
return visit_impl<I+1>(std::forward<F>(f));
}
template< const func_t vtable[] = {
size_t I, /*! Empty (type_identity == 0) case !*/
typename F [](F&&, TSU&&) {
> },
auto visit_impl( /*! Non-empty cases !*/
F&& &apply_to_contents_as_type<F&&,TSU&&,I>...
) const -> typename std::enable_if< };
(I == sizeof...(Types)) &&
std::is_same<void, return_type_t<F, const T0&>>::value
>::type
{
}
template< return vtable[me.get_current_type_id()](std::forward<F>(f), std::forward<TSU>(me));
size_t I,
typename F
>
auto visit_impl(
F&&
) const -> typename std::enable_if<
(I == sizeof...(Types)) &&
! std::is_same<void, return_type_t<F, const T0&>>::value,
return_type_t<F, const T0&>
>::type
{
return return_type_t<F, const T0&>{};
}
template<
size_t I,
typename F
>
auto visit_impl(
F&& f
) const -> typename std::enable_if<
(I < sizeof...(Types)),
return_type_t<F, const T0&>
>::type
{
if (type_identity == (I+1))
return std::forward<F>(f)(unchecked_get<get_type_t<I>>());
else
return visit_impl<I+1>(std::forward<F>(f));
} }
template <typename T> template <typename T>
...@@ -262,7 +199,7 @@ namespace dlib ...@@ -262,7 +199,7 @@ namespace dlib
void destruct () void destruct ()
{ {
visit(destruct_helper{}); apply_to_contents(destruct_helper{});
type_identity = 0; type_identity = 0;
} }
...@@ -352,7 +289,7 @@ namespace dlib ...@@ -352,7 +289,7 @@ namespace dlib
const type_safe_union& item const type_safe_union& item
) : type_safe_union() ) : type_safe_union()
{ {
item.visit(assign_to{*this}); item.apply_to_contents(assign_to{*this});
} }
type_safe_union& operator=( type_safe_union& operator=(
...@@ -362,7 +299,7 @@ namespace dlib ...@@ -362,7 +299,7 @@ namespace dlib
if (item.is_empty()) if (item.is_empty())
destruct(); destruct();
else else
item.visit(assign_to{*this}); item.apply_to_contents(assign_to{*this});
return *this; return *this;
} }
...@@ -370,7 +307,7 @@ namespace dlib ...@@ -370,7 +307,7 @@ namespace dlib
type_safe_union&& item type_safe_union&& item
) : type_safe_union() ) : type_safe_union()
{ {
item.visit(move_to{*this}); item.apply_to_contents(move_to{*this});
item.destruct(); item.destruct();
} }
...@@ -384,7 +321,7 @@ namespace dlib ...@@ -384,7 +321,7 @@ namespace dlib
} }
else else
{ {
item.visit(move_to{*this}); item.apply_to_contents(move_to{*this});
item.destruct(); item.destruct();
} }
return *this; return *this;
...@@ -423,7 +360,7 @@ namespace dlib ...@@ -423,7 +360,7 @@ namespace dlib
Args&&... args Args&&... args
) )
{ {
construct<T,Args...>(std::forward<Args>(args)...); construct<T>(std::forward<Args>(args)...);
} }
~type_safe_union() ~type_safe_union()
...@@ -445,23 +382,7 @@ namespace dlib ...@@ -445,23 +382,7 @@ namespace dlib
Args&&... args Args&&... args
) )
{ {
construct<T,Args...>(std::forward<Args>(args)...); construct<T>(std::forward<Args>(args)...);
}
template <typename F>
auto visit(
F&& f
) -> decltype(visit_impl<0>(std::forward<F>(f)))
{
return visit_impl<0>(std::forward<F>(f));
}
template <typename F>
auto visit(
F&& f
) const -> decltype(visit_impl<0>(std::forward<F>(f)))
{
return visit_impl<0>(std::forward<F>(f));
} }
template <typename F> template <typename F>
...@@ -469,7 +390,7 @@ namespace dlib ...@@ -469,7 +390,7 @@ namespace dlib
F&& f F&& f
) )
{ {
visit(std::forward<F>(f)); apply_to_contents_impl(std::forward<F>(f), *this, dlib::make_index_sequence<sizeof...(Types)>{});
} }
template <typename F> template <typename F>
...@@ -477,7 +398,7 @@ namespace dlib ...@@ -477,7 +398,7 @@ namespace dlib
F&& f F&& f
) const ) const
{ {
visit(std::forward<F>(f)); apply_to_contents_impl(std::forward<F>(f), *this, dlib::make_index_sequence<sizeof...(Types)>{});
} }
template <typename T> template <typename T>
...@@ -504,16 +425,16 @@ namespace dlib ...@@ -504,16 +425,16 @@ namespace dlib
{ {
if (type_identity == item.type_identity) if (type_identity == item.type_identity)
{ {
item.visit(swap_to{*this}); item.apply_to_contents(swap_to{*this});
} }
else if (is_empty()) else if (is_empty())
{ {
item.visit(move_to{*this}); item.apply_to_contents(move_to{*this});
item.destruct(); item.destruct();
} }
else if (item.is_empty()) else if (item.is_empty())
{ {
visit(move_to{item}); apply_to_contents(move_to{item});
destruct(); destruct();
} }
else else
...@@ -537,6 +458,16 @@ namespace dlib ...@@ -537,6 +458,16 @@ namespace dlib
return unchecked_get<T>(); return unchecked_get<T>();
} }
template <
typename T
>
T& get(
in_place_tag<T>
)
{
return get<T>();
}
template < template <
typename T, typename T,
is_valid_check<T> = true is_valid_check<T> = true
...@@ -572,52 +503,134 @@ namespace dlib ...@@ -572,52 +503,134 @@ namespace dlib
namespace detail namespace detail
{ {
struct serialize_helper template<
typename F,
typename TSU,
std::size_t... I
>
void for_each_type_impl(
F&& f,
TSU&& tsu,
dlib::index_sequence<I...>
)
{ {
serialize_helper(std::ostream& out_) : out(out_) {} using Tsu = typename std::decay<TSU>::type;
(void)std::initializer_list<int>{
(std::forward<F>(f)(
in_place_tag<type_safe_union_alternative_t<I, Tsu>>{},
std::forward<TSU>(tsu)),
0
)...
};
}
template <typename T> template<
void operator() (const T& item) const typename R,
typename F,
typename TSU,
std::size_t I
>
R visit_impl_as_type(
F&& f,
TSU&& tsu
)
{ {
serialize(item, out); using Tsu = typename std::decay<TSU>::type;
using T = type_safe_union_alternative_t<I, Tsu>;
return dlib::invoke(std::forward<F>(f), tsu.template cast_to<T>());
} }
std::ostream& out;
};
template< template<
size_t I, typename R,
typename... Types typename F,
typename TSU,
std::size_t... I
> >
inline typename std::enable_if<(I == sizeof...(Types))>::type deserialize_helper( R visit_impl(
std::istream&, F&& f,
int, TSU&& tsu,
type_safe_union<Types...>& dlib::index_sequence<I...>
) )
{ {
using func_t = R(*)(F&&, TSU&&);
const func_t vtable[] = {
/*! Empty (type_identity == 0) case !*/
[](F&&, TSU&&) {
return R();
},
/*! Non-empty cases !*/
&visit_impl_as_type<R,F&&,TSU&&,I>...
};
return vtable[tsu.get_current_type_id()](std::forward<F>(f), std::forward<TSU>(tsu));
}
} }
template< template<
size_t I, typename TSU,
typename... Types typename F
> >
inline typename std::enable_if<(I < sizeof...(Types))>::type deserialize_helper( void for_each_type(
std::istream& in, F&& f,
int index, TSU&& tsu
type_safe_union<Types...>& x
) )
{ {
using T = typename internal::variant_get_type<I, Types...>::type; using Tsu = typename std::decay<TSU>::type;
static constexpr std::size_t Size = type_safe_union_size<Tsu>::value;
detail::for_each_type_impl(std::forward<F>(f), std::forward<TSU>(tsu), dlib::make_index_sequence<Size>{});
}
if (index == (I+1)) template<
typename F,
typename TSU,
typename Tsu = typename std::decay<TSU>::type,
typename T0 = type_safe_union_alternative_t<0, Tsu>
>
auto visit(
F&& f,
TSU&& tsu
) -> dlib::invoke_result_t<F, decltype(tsu.template cast_to<T0>())>
{ {
deserialize(x.template get<T>(), in); using ReturnType = dlib::invoke_result_t<F, decltype(tsu.template cast_to<T0>())>;
static constexpr std::size_t Size = type_safe_union_size<Tsu>::value;
return detail::visit_impl<ReturnType>(std::forward<F>(f), std::forward<TSU>(tsu), dlib::make_index_sequence<Size>{});
} }
else
namespace detail
{
struct serialize_helper
{ {
deserialize_helper<I+1>(in, index, x); serialize_helper(std::ostream& out_) : out(out_) {}
template <typename T>
void operator() (const T& item) const
{
serialize(item, out);
} }
std::ostream& out;
};
struct deserialize_helper
{
deserialize_helper(
std::istream& in_,
int index_
) : index(index_),
in(in_)
{}
template<typename T, typename TSU>
void operator()(in_place_tag<T>, TSU&& x)
{
if (index == x.template get_type_id<T>())
deserialize(x.template get<T>(), in);
} }
const int index = -1;
std::istream& in;
};
} // namespace detail } // namespace detail
template<typename... Types> template<typename... Types>
...@@ -629,7 +642,7 @@ namespace dlib ...@@ -629,7 +642,7 @@ namespace dlib
try try
{ {
serialize(item.get_current_type_id(), out); serialize(item.get_current_type_id(), out);
item.visit(detail::serialize_helper(out)); item.apply_to_contents(detail::serialize_helper(out));
} }
catch (serialization_error& e) catch (serialization_error& e)
{ {
...@@ -651,7 +664,7 @@ namespace dlib ...@@ -651,7 +664,7 @@ namespace dlib
if (index == 0) if (index == 0)
item.clear(); item.clear();
else if (index > 0 && index <= (int)sizeof...(Types)) else if (index > 0 && index <= (int)sizeof...(Types))
detail::deserialize_helper<0>(in, index, item); for_each_type(detail::deserialize_helper(in, index), item);
else else
throw serialization_error("bad index value. Should be in range [0,sizeof...(Types))"); throw serialization_error("bad index value. Should be in range [0,sizeof...(Types))");
} }
......
...@@ -19,7 +19,7 @@ namespace dlib ...@@ -19,7 +19,7 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template<typename T> template<typename T>
struct in_place_tag {}; struct in_place_tag { using type = T; };
/*! /*!
This is an empty class type used as a special disambiguation tag to be This is an empty class type used as a special disambiguation tag to be
passed as the first argument to the constructor of type_safe_union that performs passed as the first argument to the constructor of type_safe_union that performs
...@@ -36,8 +36,42 @@ namespace dlib ...@@ -36,8 +36,42 @@ namespace dlib
}; };
using tsu = type_safe_union<A,std::string>; using tsu = type_safe_union<A,std::string>;
tsu a(in_place_tag<A>{}, 0, 1); // a now contains an object of type A
tsu a(in_place_tag<A>{}, 0, 1); It is also used with type_safe_union::for_each() to disambiguate types.
!*/
// ----------------------------------------------------------------------------------------
template<typename TSU>
struct type_safe_union_size
{
static constexpr size_t value = The number of types in the TSU.
};
/*!
requires
- TSU must be of type type_safe_union<Types...> with possible cv qualification
ensures
- value contains the number of types in TSU, i.e. sizeof...(Types...)
!*/
// ----------------------------------------------------------------------------------------
template <size_t I, typename TSU>
struct type_safe_union_alternative;
/*!
requires
- TSU is a type_safe_union
ensures
- type_safe_union_alternative<I, TSU>::type is the Ith type in the TSU.
- TSU::get_type_id<typename type_safe_union_alternative<I, TSU>::type>() == I
!*/
template<size_t I, typename TSU>
using type_safe_union_alternative_t = type_safe_union_alternative<I,TSU>::type;
/*!
ensures
- provides template alias for type_safe_union_alternative
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -195,6 +229,8 @@ namespace dlib ...@@ -195,6 +229,8 @@ namespace dlib
- if (T is the same type as one of the template arguments) then - if (T is the same type as one of the template arguments) then
- returns a number indicating which template argument it is. In particular, - returns a number indicating which template argument it is. In particular,
if it's the first template argument it returns 1, if the second then 2, and so on. if it's the first template argument it returns 1, if the second then 2, and so on.
- else if (T is in_place_tag<U>) then
- equivalent to returning get_type_id<U>())
- else - else
- returns -1 - returns -1
!*/ !*/
...@@ -224,14 +260,16 @@ namespace dlib ...@@ -224,14 +260,16 @@ namespace dlib
) const; ) const;
/*! /*!
ensures ensures
- returns type_identity, i.e, the index of the currently held type. - if (is_empty()) then
For example if the current type is the first template argument it returns 1, if it's the second then 2, and so on.
If the current object is empty, i.e. is_empty() == true, then
- returns 0 - returns 0
- else
- Returns the type id of the currently held type. This is the same as
get_type_id<WhateverTypeIsCurrentlyHeld>(). Therefore, if the current type is
the first template argument it returns 1, if it's the second then 2, and so on.
!*/ !*/
template <typename F> template <typename F>
auto visit( void apply_to_contents(
F&& f F&& f
); );
/*! /*!
...@@ -242,12 +280,12 @@ namespace dlib ...@@ -242,12 +280,12 @@ namespace dlib
ensures ensures
- if (is_empty() == false) then - if (is_empty() == false) then
- Let U denote the type of object currently contained in this type_safe_union - Let U denote the type of object currently contained in this type_safe_union
- returns std::forward<F>(f)(this->get<U>()) - calls std::forward<F>(f)(this->get<U>())
- The object passed to f() (i.e. by this->get<U>()) will be non-const. - The object passed to f() (i.e. by this->get<U>()) will be non-const.
!*/ !*/
template <typename F> template <typename F>
auto visit( void apply_to_contents(
F&& f F&& f
) const; ) const;
/*! /*!
...@@ -258,28 +296,10 @@ namespace dlib ...@@ -258,28 +296,10 @@ namespace dlib
ensures ensures
- if (is_empty() == false) then - if (is_empty() == false) then
- Let U denote the type of object currently contained in this type_safe_union - Let U denote the type of object currently contained in this type_safe_union
- returns std::forward<F>(f)(this->get<U>()) - calls std::forward<F>(f)(this->get<U>())
- The object passed to f() (i.e. by this->get<U>()) will be const. - The object passed to f() (i.e. by this->get<U>()) will be const.
!*/ !*/
template <typename F>
void apply_to_contents(
F&& f
);
/*!
ensures:
equivalent to calling visit(std::forward<F>(f)) with void return type
!*/
template <typename F>
void apply_to_contents(
F&& f
) const;
/*!
ensures:
equivalent to calling visit(std::forward<F>(f)) with void return type
!*/
template <typename T> template <typename T>
T& get( T& get(
); );
...@@ -298,6 +318,15 @@ namespace dlib ...@@ -298,6 +318,15 @@ namespace dlib
- returns a non-const reference to the newly created T object. - returns a non-const reference to the newly created T object.
!*/ !*/
template <typename T>
T& get(
const in_place_tag<T>& tag
);
/*!
ensures
- equivalent to calling get<T>()
!*/
template <typename T> template <typename T>
const T& cast_to ( const T& cast_to (
) const; ) const;
...@@ -345,6 +374,51 @@ namespace dlib ...@@ -345,6 +374,51 @@ namespace dlib
provides a global swap function provides a global swap function
!*/ !*/
// ----------------------------------------------------------------------------------------
template<
typename TSU,
typename F
>
void for_each_type(
F&& f,
TSU&& tsu
);
/*!
requires
- tsu is an object of type type_safe_union<Types...> for some types Types...
- f is a callable object such that the following expression is valid for
all types U in Types...:
std::forward<F>(f)(in_place_tag<U>{}, std::forward<TSU>(tsu))
ensures
- This function iterates over all types U in Types... and calls:
std::forward<F>(f)(in_place_tag<U>{}, std::forward<TSU>(tsu))
!*/
// ----------------------------------------------------------------------------------------
template<
typename F,
typename TSU
>
auto visit(
F&& f,
TSU&& tsu
);
/*!
requires
- tsu is an object of type type_safe_union<Types...> for some types Types...
- f is a callable object capable of operating on all the types contained
in tsu. I.e. std::forward<F>(f)(this->get<U>()) must be a valid
expression for all the possible U types.
ensures
- if (tsu.is_empty() == false) then
- Let U denote the type of object currently contained in tsu.
- returns std::forward<F>(f)(this->get<U>())
- The object passed to f() (i.e. by this->get<U>()) will have the same reference
type as TSU.
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template<typename... Types> template<typename... Types>
...@@ -381,7 +455,7 @@ namespace dlib ...@@ -381,7 +455,7 @@ namespace dlib
} }
/*! /*!
This is a helper function for passing many callable objects (usually lambdas) This is a helper function for passing many callable objects (usually lambdas)
to either apply_to_contents() or visit(), that combine to make a complete to either apply_to_contents(), visit() or for_each(), that combine to make a complete
visitor. A picture paints a thousand words: visitor. A picture paints a thousand words:
using tsu = type_safe_union<int,float,std::string>; using tsu = type_safe_union<int,float,std::string>;
...@@ -405,7 +479,7 @@ namespace dlib ...@@ -405,7 +479,7 @@ namespace dlib
assert(result == "hello there"); assert(result == "hello there");
result = ""; result = "";
result = a.visit(overloaded( result = visit(overloaded(
[](int) { [](int) {
return std::string("int"); return std::string("int");
}, },
...@@ -415,9 +489,25 @@ namespace dlib ...@@ -415,9 +489,25 @@ namespace dlib
[](const std::string& item) { [](const std::string& item) {
return item; return item;
} }
)); ), a);
assert(result == "hello there"); assert(result == "hello there");
std::vector<int> type_ids;
for_each_type(a, overloaded(
[&type_ids](in_place_tag<int>, tsu& me) {
type_ids.push_back(me.get_type_id<int>());
},
[&type_ids](in_place_tag<float>, tsu& me) {
type_ids.push_back(me.get_type_id<float>());
},
[&type_ids](in_place_tag<std::string>, tsu& me) {
type_ids.push_back(me.get_type_id<std::string>());
}
));
assert(type_ids == vector<int>({0,1,2}));
!*/ !*/
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment