Unverified Commit aaac87a2 authored by pfeatherstone's avatar pfeatherstone Committed by GitHub
Browse files

[TYPE_SAFE_UNION] simplified some type traits and added for_each(). (#2475)

* [TYPE_SAFE_UNION] simplified some type traits and added for_each().

* added example serialization/deserialization using typeid().hash_code and different type_safe_union types

* in_place_tag is an empty struct. so don't pass const references, you're unecessarily passing 8 bytes around for now reason

* - added variant_size for type_safe_union
- added variant_alternative for type_safe_union
- removed for type_safe_union::for_each() and replaced with global function dlib::for_each_type()

* - made visit() a global

* use dlib::invoke explicitly

* - for_each_type is implemented using fold expression (or whatever the right term is) instead of template recursion. This method, in theory, yields better compile times. And if you're familiar with parameter packs, then the implementation is easier to read.

* - refactoring
- reordered function parameters in for_each_type()
- vtable implementation of apply_to_contents and visit() (sorry Davis for yet another change)

* add option to not zero out gradients and method to do it (#2477)

* Avoid different kinds of compiler warnings (#2481)

* Avoid different kinds of compiler warnings that started to appear when upgrading my build environment

* Avoid more compiler warnings

* Revert the overly verbose static_cast changes

* Make resize_bilinear and resize_bilinear_gradient take long long (previously just long)

* Circumvent what appears to be a bug in Visual Studio 2019's optimizer
(see: https://forum.juce.com/t/warning-in-the-lastest-vs2019/38267

)

* Fix MSVC pragma warnings with other compilers (#2483)

* Fix warning about unused zero_gradients parameter (#2487)

* Fix warning about unused zero_gradients parameter

* match signature of other methods

* cleanup
Co-authored-by: default avatarpfeatherstone <peter@me>
Co-authored-by: default avatarAdrià Arrufat <1671644+arrufat@users.noreply.github.com>
Co-authored-by: default avatarJuha Reunanen <juha.reunanen@tomaattinen.com>
Co-authored-by: default avatarDavis King <davis@dlib.net>
parent 42e08696
...@@ -80,6 +80,13 @@ namespace ...@@ -80,6 +80,13 @@ namespace
public: public:
void test_stuff() void test_stuff()
{ {
static_assert(tsu::get_type_id<float>() == 1, "bad type id");
static_assert(tsu::get_type_id<double>() == 2, "bad type id");
static_assert(tsu::get_type_id<char>() == 3, "bad type id");
static_assert(tsu::get_type_id<std::string>() == 4, "bad type id");
static_assert(tsu::get_type_id<long>() == -1, "This should be -1");
DLIB_TEST(a.is_empty() == true); DLIB_TEST(a.is_empty() == true);
DLIB_TEST(a.contains<char>() == false); DLIB_TEST(a.contains<char>() == false);
DLIB_TEST(a.contains<float>() == false); DLIB_TEST(a.contains<float>() == false);
...@@ -566,7 +573,7 @@ namespace ...@@ -566,7 +573,7 @@ namespace
DLIB_TEST(b.is_empty()); DLIB_TEST(b.is_empty());
DLIB_TEST(b.get_current_type_id() == 0); DLIB_TEST(b.get_current_type_id() == 0);
//visit can return non-void types //visit can return non-void types
auto ret = a.visit(overloaded( auto ret = visit(overloaded(
[](int) { [](int) {
return std::string("int"); return std::string("int");
}, },
...@@ -576,7 +583,7 @@ namespace ...@@ -576,7 +583,7 @@ namespace
[](const std::string&) { [](const std::string&) {
return std::string("std::string"); return std::string("std::string");
} }
)); ), a);
static_assert(std::is_same<std::string, decltype(ret)>::value, "bad return type"); static_assert(std::is_same<std::string, decltype(ret)>::value, "bad return type");
DLIB_TEST(ret == "int"); DLIB_TEST(ret == "int");
//apply_to_contents can only return void //apply_to_contents can only return void
...@@ -604,7 +611,7 @@ namespace ...@@ -604,7 +611,7 @@ namespace
tsu_b object(dlib::in_place_tag<tsu_a>{}, std::string("hello from bottom node")); tsu_b object(dlib::in_place_tag<tsu_a>{}, std::string("hello from bottom node"));
DLIB_TEST(object.contains<tsu_a>()); DLIB_TEST(object.contains<tsu_a>());
DLIB_TEST(object.get<tsu_a>().get<std::string>() == "hello from bottom node"); DLIB_TEST(object.get<tsu_a>().get<std::string>() == "hello from bottom node");
auto ret = object.visit(overloaded( auto ret = visit(overloaded(
[](int) { [](int) {
return std::string("int"); return std::string("int");
}, },
...@@ -615,7 +622,7 @@ namespace ...@@ -615,7 +622,7 @@ namespace
return std::string("std::string"); return std::string("std::string");
}, },
[](const tsu_a& item) { [](const tsu_a& item) {
return item.visit(overloaded( return visit( overloaded(
[](int) { [](int) {
return std::string("nested int"); return std::string("nested int");
}, },
...@@ -625,20 +632,19 @@ namespace ...@@ -625,20 +632,19 @@ namespace
[](std::string str) { [](std::string str) {
return str; return str;
} }
)); ), item);
} }
)); ), object);
static_assert(std::is_same<std::string, decltype(ret)>::value, "bad type"); static_assert(std::is_same<std::string, decltype(ret)>::value, "bad type");
DLIB_TEST(ret == "hello from bottom node"); DLIB_TEST(ret == "hello from bottom node");
} }
{ {
//"private" visitor //struct visitor
using tsu = type_safe_union<int,float,std::string>; using tsu = type_safe_union<int,float,std::string>;
class visitor_private struct visitor_private
{ {
private:
std::string operator()(int) std::string operator()(int)
{ {
return std::string("int"); return std::string("int");
...@@ -653,19 +659,162 @@ namespace ...@@ -653,19 +659,162 @@ namespace
{ {
return str; return str;
} }
friend tsu;
}; };
visitor_private visitor; visitor_private visitor;
tsu a = std::string("hello from private visitor"); tsu a = std::string("hello from private visitor");
auto ret = a.visit(visitor); auto ret = visit(visitor, a);
static_assert(std::is_same<std::string, decltype(ret)>::value, "bad type"); static_assert(std::is_same<std::string, decltype(ret)>::value, "bad type");
DLIB_TEST(ret == "hello from private visitor"); DLIB_TEST(ret == "hello from private visitor");
} }
} }
}; };
namespace test_for_each_1
{
/*! Local classes aren't allowed to have template member functions... !*/
using tsu = type_safe_union<int,float,std::string>;
static_assert(type_safe_union_size<tsu>::value == 3, "bad number of types");
static_assert(std::is_same<type_safe_union_alternative_t<0, tsu>, int>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<1, tsu>, float>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<2, tsu>, std::string>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<0, const tsu>, const int>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<1, const tsu>, const float>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<2, const tsu>, const std::string>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<0, volatile tsu>, volatile int>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<1, volatile tsu>, volatile float>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<2, volatile tsu>, volatile std::string>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<0, const volatile tsu>, const volatile int>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<1, const volatile tsu>, const volatile float>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<2, const volatile tsu>, const volatile std::string>::value, "bad type");
struct for_each_visitor
{
std::vector<int> type_indices;
template<typename T>
void operator()(dlib::in_place_tag<T>, const tsu& item)
{
type_indices.push_back(item.get_type_id<T>());
}
};
void test()
{
tsu a;
for_each_visitor visitor;
for_each_type(visitor, a);
DLIB_TEST(visitor.type_indices.size() == 3);
DLIB_TEST(visitor.type_indices[0] == 1);
DLIB_TEST(visitor.type_indices[1] == 2);
DLIB_TEST(visitor.type_indices[2] == 3);
}
}
namespace test_for_each_2
{
/*! Local classes aren't allowed to have template member functions... !*/
using tsu = type_safe_union<int,float,std::string>;
//for_each() that demonstrates an actual use-case.
//Instead of something simple like a target index, you might want to set the variant
//based on some specific state that is unique to one of the alternative types.
//For example, you might want to conditionally set the variant based on hashes.
struct for_each_visitor
{
for_each_visitor(int target_index_) : target_index(target_index_) {}
template<typename TagType>
void operator()(TagType tag, tsu& item)
{
if (item.get_type_id<TagType>() == target_index)
item = tsu{tag};
}
const int target_index = 0;
};
void test()
{
tsu a;
for_each_type(for_each_visitor{1}, a);
DLIB_TEST(a.contains<int>());
a.clear();
for_each_type(for_each_visitor{2}, a);
DLIB_TEST(a.contains<float>());
a.clear();
for_each_type(for_each_visitor{3}, a);
DLIB_TEST(a.contains<std::string>());
a.clear();
for_each_type(for_each_visitor{-1}, a);
DLIB_TEST(a.is_empty());
a.clear();
for_each_type(for_each_visitor{215465}, a);
DLIB_TEST(a.is_empty());
}
}
namespace test_for_each_3
{
using tsu1 = type_safe_union<int,float,std::string>;
using tsu2 = type_safe_union<std::string,long,char>;
struct serializer_typeid
{
serializer_typeid(std::ostream& out_) : out(out_) {}
template<typename T>
void operator()(const T& x)
{
dlib::serialize(typeid(T).hash_code(), out);
dlib::serialize(x, out);
}
std::ostream& out;
};
struct deserializer_typeid
{
deserializer_typeid(std::istream& in_) : in(in_)
{
dlib::deserialize(hash_code, in);
}
template<typename T, typename TSU>
void operator()(in_place_tag<T>, TSU&& me)
{
if (typeid(T).hash_code() == hash_code)
dlib::deserialize(me.template get<T>(), in);
}
std::size_t hash_code = 0;
std::istream& in;
};
void test()
{
tsu1 a;
a.get<std::string>() = "hello from tsu1";
std::stringstream out;
visit(serializer_typeid(out), a);
tsu2 b;
for_each_type(deserializer_typeid(out), b);
DLIB_TEST(b.contains<std::string>());
DLIB_TEST(b.get<std::string>() == "hello from tsu1");
}
}
class type_safe_union_tester : public tester class type_safe_union_tester : public tester
{ {
public: public:
...@@ -682,6 +831,9 @@ namespace ...@@ -682,6 +831,9 @@ namespace
{ {
test a; test a;
a.test_stuff(); a.test_stuff();
test_for_each_1::test();
test_for_each_2::test();
test_for_each_3::test();
} }
} }
} a; } a;
......
...@@ -19,7 +19,7 @@ namespace dlib ...@@ -19,7 +19,7 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template<typename T> template<typename T>
struct in_place_tag {}; struct in_place_tag { using type = T; };
/*! /*!
This is an empty class type used as a special disambiguation tag to be This is an empty class type used as a special disambiguation tag to be
passed as the first argument to the constructor of type_safe_union that performs passed as the first argument to the constructor of type_safe_union that performs
...@@ -36,8 +36,42 @@ namespace dlib ...@@ -36,8 +36,42 @@ namespace dlib
}; };
using tsu = type_safe_union<A,std::string>; using tsu = type_safe_union<A,std::string>;
tsu a(in_place_tag<A>{}, 0, 1); // a now contains an object of type A
tsu a(in_place_tag<A>{}, 0, 1); It is also used with type_safe_union::for_each() to disambiguate types.
!*/
// ----------------------------------------------------------------------------------------
template<typename TSU>
struct type_safe_union_size
{
static constexpr size_t value = The number of types in the TSU.
};
/*!
requires
- TSU must be of type type_safe_union<Types...> with possible cv qualification
ensures
- value contains the number of types in TSU, i.e. sizeof...(Types...)
!*/
// ----------------------------------------------------------------------------------------
template <size_t I, typename TSU>
struct type_safe_union_alternative;
/*!
requires
- TSU is a type_safe_union
ensures
- type_safe_union_alternative<I, TSU>::type is the Ith type in the TSU.
- TSU::get_type_id<typename type_safe_union_alternative<I, TSU>::type>() == I
!*/
template<size_t I, typename TSU>
using type_safe_union_alternative_t = type_safe_union_alternative<I,TSU>::type;
/*!
ensures
- provides template alias for type_safe_union_alternative
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -195,6 +229,8 @@ namespace dlib ...@@ -195,6 +229,8 @@ namespace dlib
- if (T is the same type as one of the template arguments) then - if (T is the same type as one of the template arguments) then
- returns a number indicating which template argument it is. In particular, - returns a number indicating which template argument it is. In particular,
if it's the first template argument it returns 1, if the second then 2, and so on. if it's the first template argument it returns 1, if the second then 2, and so on.
- else if (T is in_place_tag<U>) then
- equivalent to returning get_type_id<U>())
- else - else
- returns -1 - returns -1
!*/ !*/
...@@ -224,14 +260,16 @@ namespace dlib ...@@ -224,14 +260,16 @@ namespace dlib
) const; ) const;
/*! /*!
ensures ensures
- returns type_identity, i.e, the index of the currently held type. - if (is_empty()) then
For example if the current type is the first template argument it returns 1, if it's the second then 2, and so on.
If the current object is empty, i.e. is_empty() == true, then
- returns 0 - returns 0
- else
- Returns the type id of the currently held type. This is the same as
get_type_id<WhateverTypeIsCurrentlyHeld>(). Therefore, if the current type is
the first template argument it returns 1, if it's the second then 2, and so on.
!*/ !*/
template <typename F> template <typename F>
auto visit( void apply_to_contents(
F&& f F&& f
); );
/*! /*!
...@@ -242,12 +280,12 @@ namespace dlib ...@@ -242,12 +280,12 @@ namespace dlib
ensures ensures
- if (is_empty() == false) then - if (is_empty() == false) then
- Let U denote the type of object currently contained in this type_safe_union - Let U denote the type of object currently contained in this type_safe_union
- returns std::forward<F>(f)(this->get<U>()) - calls std::forward<F>(f)(this->get<U>())
- The object passed to f() (i.e. by this->get<U>()) will be non-const. - The object passed to f() (i.e. by this->get<U>()) will be non-const.
!*/ !*/
template <typename F> template <typename F>
auto visit( void apply_to_contents(
F&& f F&& f
) const; ) const;
/*! /*!
...@@ -258,28 +296,10 @@ namespace dlib ...@@ -258,28 +296,10 @@ namespace dlib
ensures ensures
- if (is_empty() == false) then - if (is_empty() == false) then
- Let U denote the type of object currently contained in this type_safe_union - Let U denote the type of object currently contained in this type_safe_union
- returns std::forward<F>(f)(this->get<U>()) - calls std::forward<F>(f)(this->get<U>())
- The object passed to f() (i.e. by this->get<U>()) will be const. - The object passed to f() (i.e. by this->get<U>()) will be const.
!*/ !*/
template <typename F>
void apply_to_contents(
F&& f
);
/*!
ensures:
equivalent to calling visit(std::forward<F>(f)) with void return type
!*/
template <typename F>
void apply_to_contents(
F&& f
) const;
/*!
ensures:
equivalent to calling visit(std::forward<F>(f)) with void return type
!*/
template <typename T> template <typename T>
T& get( T& get(
); );
...@@ -298,6 +318,15 @@ namespace dlib ...@@ -298,6 +318,15 @@ namespace dlib
- returns a non-const reference to the newly created T object. - returns a non-const reference to the newly created T object.
!*/ !*/
template <typename T>
T& get(
const in_place_tag<T>& tag
);
/*!
ensures
- equivalent to calling get<T>()
!*/
template <typename T> template <typename T>
const T& cast_to ( const T& cast_to (
) const; ) const;
...@@ -345,6 +374,51 @@ namespace dlib ...@@ -345,6 +374,51 @@ namespace dlib
provides a global swap function provides a global swap function
!*/ !*/
// ----------------------------------------------------------------------------------------
template<
typename TSU,
typename F
>
void for_each_type(
F&& f,
TSU&& tsu
);
/*!
requires
- tsu is an object of type type_safe_union<Types...> for some types Types...
- f is a callable object such that the following expression is valid for
all types U in Types...:
std::forward<F>(f)(in_place_tag<U>{}, std::forward<TSU>(tsu))
ensures
- This function iterates over all types U in Types... and calls:
std::forward<F>(f)(in_place_tag<U>{}, std::forward<TSU>(tsu))
!*/
// ----------------------------------------------------------------------------------------
template<
typename F,
typename TSU
>
auto visit(
F&& f,
TSU&& tsu
);
/*!
requires
- tsu is an object of type type_safe_union<Types...> for some types Types...
- f is a callable object capable of operating on all the types contained
in tsu. I.e. std::forward<F>(f)(this->get<U>()) must be a valid
expression for all the possible U types.
ensures
- if (tsu.is_empty() == false) then
- Let U denote the type of object currently contained in tsu.
- returns std::forward<F>(f)(this->get<U>())
- The object passed to f() (i.e. by this->get<U>()) will have the same reference
type as TSU.
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template<typename... Types> template<typename... Types>
...@@ -381,7 +455,7 @@ namespace dlib ...@@ -381,7 +455,7 @@ namespace dlib
} }
/*! /*!
This is a helper function for passing many callable objects (usually lambdas) This is a helper function for passing many callable objects (usually lambdas)
to either apply_to_contents() or visit(), that combine to make a complete to either apply_to_contents(), visit() or for_each(), that combine to make a complete
visitor. A picture paints a thousand words: visitor. A picture paints a thousand words:
using tsu = type_safe_union<int,float,std::string>; using tsu = type_safe_union<int,float,std::string>;
...@@ -405,7 +479,7 @@ namespace dlib ...@@ -405,7 +479,7 @@ namespace dlib
assert(result == "hello there"); assert(result == "hello there");
result = ""; result = "";
result = a.visit(overloaded( result = visit(overloaded(
[](int) { [](int) {
return std::string("int"); return std::string("int");
}, },
...@@ -415,9 +489,25 @@ namespace dlib ...@@ -415,9 +489,25 @@ namespace dlib
[](const std::string& item) { [](const std::string& item) {
return item; return item;
} }
)); ), a);
assert(result == "hello there"); assert(result == "hello there");
std::vector<int> type_ids;
for_each_type(a, overloaded(
[&type_ids](in_place_tag<int>, tsu& me) {
type_ids.push_back(me.get_type_id<int>());
},
[&type_ids](in_place_tag<float>, tsu& me) {
type_ids.push_back(me.get_type_id<float>());
},
[&type_ids](in_place_tag<std::string>, tsu& me) {
type_ids.push_back(me.get_type_id<std::string>());
}
));
assert(type_ids == vector<int>({0,1,2}));
!*/ !*/
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment