"examples/git@developer.sourcefind.cn:modelzoo/bert4torch.git" did not exist on "58935387458ba1dee228fbd261bf9878a5d7ae32"
Unverified Commit d9e58d66 authored by pfeatherstone's avatar pfeatherstone Committed by GitHub
Browse files

Fixes bug when (de)serializing vector<complex<float>> (#2244)



* [SERIALIZATION] fixed bug when (de)serializing vector<complex<float>>. DLIB_DEFINE... macro uses __out and __in variables names for ostream and istream objects respectively to avoid member variable name conflicts.

* Refactoring objects in DLIB_DEFINE_DEFAULT_SERIALIZATION to avoid name conflicts with user types

* Refactoring objects in DLIB_DEFINE_DEFAULT_SERIALIZATION to avoid name conflicts with user types

* removed tabs

* removed more tabs
Co-authored-by: default avatarpf <pf@pf-ubuntu-dev>
parent a7627cbd
...@@ -770,6 +770,50 @@ namespace dlib ...@@ -770,6 +770,50 @@ namespace dlib
deserialize_floating_point(item,in); deserialize_floating_point(item,in);
} }
// ----------------------------------------------------------------------------------------
template <
typename T
>
inline void serialize (
const std::complex<T>& item,
std::ostream& out
)
{
try
{
serialize(item.real(),out);
serialize(item.imag(),out);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing an object of type std::complex");
}
}
// ----------------------------------------------------------------------------------------
template <
typename T
>
inline void deserialize (
std::complex<T>& item,
std::istream& in
)
{
try
{
T real, imag;
deserialize(real,in);
deserialize(imag,in);
item = std::complex<T>(real,imag);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while deserializing an object of type std::complex");
}
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// prototypes // prototypes
...@@ -2118,50 +2162,6 @@ namespace dlib ...@@ -2118,50 +2162,6 @@ namespace dlib
} }
} }
// ----------------------------------------------------------------------------------------
template <
typename T
>
inline void serialize (
const std::complex<T>& item,
std::ostream& out
)
{
try
{
serialize(item.real(),out);
serialize(item.imag(),out);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing an object of type std::complex");
}
}
// ----------------------------------------------------------------------------------------
template <
typename T
>
inline void deserialize (
std::complex<T>& item,
std::istream& in
)
{
try
{
T real, imag;
deserialize(real,in);
deserialize(imag,in);
item = std::complex<T>(real,imag);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while deserializing an object of type std::complex");
}
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename T, typename deleter> template <typename T, typename deleter>
...@@ -2555,7 +2555,7 @@ namespace dlib ...@@ -2555,7 +2555,7 @@ namespace dlib
} }
#define DLIB_DEFINE_DEFAULT_SERIALIZATION(Type, ...) \ #define DLIB_DEFINE_DEFAULT_SERIALIZATION(Type, ...) \
void serialize_to(std::ostream& out) const \ void serialize_to(std::ostream& dlibDefaultSer$_out) const \
{ \ { \
using dlib::serialize; \ using dlib::serialize; \
using dlib::serialize_these; \ using dlib::serialize_these; \
...@@ -2565,9 +2565,9 @@ namespace dlib ...@@ -2565,9 +2565,9 @@ namespace dlib
/* you realize you need to change the serialization */ \ /* you realize you need to change the serialization */ \
/* format you can identify which version of the format */ \ /* format you can identify which version of the format */ \
/* you are encountering when reading old files. */ \ /* you are encountering when reading old files. */ \
int version = 1; \ int dlibDefaultSer$_version = 1; \
serialize(version, out); \ serialize(dlibDefaultSer$_version, dlibDefaultSer$_out); \
serialize_these(out, __VA_ARGS__); \ serialize_these(dlibDefaultSer$_out, __VA_ARGS__); \
} \ } \
catch (dlib::serialization_error& e) \ catch (dlib::serialization_error& e) \
{ \ { \
...@@ -2575,17 +2575,17 @@ namespace dlib ...@@ -2575,17 +2575,17 @@ namespace dlib
} \ } \
} \ } \
\ \
void deserialize_from(std::istream& in) \ void deserialize_from(std::istream& dlibDefaultSer$_in) \
{ \ { \
using dlib::deserialize; \ using dlib::deserialize; \
using dlib::deserialize_these; \ using dlib::deserialize_these; \
try \ try \
{ \ { \
int version = 0; \ int dlibDefaultSer$_version = 0; \
deserialize(version, in); \ deserialize(dlibDefaultSer$_version, dlibDefaultSer$_in); \
if (version != 1) \ if (dlibDefaultSer$_version != 1) \
throw dlib::serialization_error("Unexpected version found while deserializing " #Type); \ throw dlib::serialization_error("Unexpected version found while deserializing " #Type); \
deserialize_these(in, __VA_ARGS__); \ deserialize_these(dlibDefaultSer$_in, __VA_ARGS__); \
} \ } \
catch (dlib::serialization_error& e) \ catch (dlib::serialization_error& e) \
{ \ { \
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <ctime> #include <ctime>
#include <dlib/serialize.h> #include <dlib/serialize.h>
#include <dlib/image_transforms.h> #include <dlib/image_transforms.h>
#include <dlib/rand.h>
#include "tester.h" #include "tester.h"
...@@ -432,15 +433,16 @@ namespace ...@@ -432,15 +433,16 @@ namespace
std::unordered_multiset<string> o; std::unordered_multiset<string> o;
std::shared_ptr<string> ptr_shared1; std::shared_ptr<string> ptr_shared1;
std::shared_ptr<string> ptr_shared2; std::shared_ptr<string> ptr_shared2;
std::vector<std::complex<double>> p;
bool operator==(const my_custom_type& rhs) const bool operator==(const my_custom_type& rhs) const
{ {
return std::tie(a,b,c,d,e,f,g,h,i,j,k,l,m,n,o) == std::tie(rhs.a,rhs.b,rhs.c,rhs.d,rhs.e,rhs.f,rhs.g,rhs.h,rhs.i,rhs.j,rhs.k,rhs.l,rhs.m,rhs.n,rhs.o) return std::tie(a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p) == std::tie(rhs.a,rhs.b,rhs.c,rhs.d,rhs.e,rhs.f,rhs.g,rhs.h,rhs.i,rhs.j,rhs.k,rhs.l,rhs.m,rhs.n,rhs.o,rhs.p)
&& pointers_values_equal(ptr_shared1, rhs.ptr_shared1) && pointers_values_equal(ptr_shared1, rhs.ptr_shared1)
&& pointers_values_equal(ptr_shared2, rhs.ptr_shared2); && pointers_values_equal(ptr_shared2, rhs.ptr_shared2);
} }
DLIB_DEFINE_DEFAULT_SERIALIZATION(my_custom_type, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, ptr_shared1, ptr_shared2); DLIB_DEFINE_DEFAULT_SERIALIZATION(my_custom_type, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, ptr_shared1, ptr_shared2);
}; };
struct my_custom_type_array struct my_custom_type_array
...@@ -1119,6 +1121,9 @@ namespace ...@@ -1119,6 +1121,9 @@ namespace
t1.o.insert("hello from unordered_multiset"); t1.o.insert("hello from unordered_multiset");
t1.o.insert("hello from unordered_multiset"); t1.o.insert("hello from unordered_multiset");
t1.ptr_shared1 = make_shared<string>("hello from shared_ptr"); t1.ptr_shared1 = make_shared<string>("hello from shared_ptr");
dlib::rand rng(std::time(NULL));
for (int i = 0 ; i < 1024 ; i++)
t1.p.push_back(rng.get_random_gaussian());
t2.a = 2; t2.a = 2;
t2.b = 4.0; t2.b = 4.0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment