"vscode:/vscode.git/clone" did not exist on "fbf61f465b7756bd3d01a272ea994741c3cfcf8c"
Unverified Commit ebca1188 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Feature] Support serialization for smart pointer (#1291)

* fix script

* t

* fix weird bugs

* fix

* fix

* upload

* fix

* fix

* lint

* fix
parent 37d992ec
......@@ -15,14 +15,14 @@
namespace dmlc {
namespace serializer {
template<>
template <>
struct Handler<DLDataType> {
inline static void Write(Stream *strm, const DLDataType& dtype) {
inline static void Write(Stream *strm, const DLDataType &dtype) {
Handler<uint8_t>::Write(strm, dtype.code);
Handler<uint8_t>::Write(strm, dtype.bits);
Handler<uint16_t>::Write(strm, dtype.lanes);
}
inline static bool Read(Stream *strm, DLDataType* dtype) {
inline static bool Read(Stream *strm, DLDataType *dtype) {
if (!Handler<uint8_t>::Read(strm, &(dtype->code))) return false;
if (!Handler<uint8_t>::Read(strm, &(dtype->bits))) return false;
if (!Handler<uint16_t>::Read(strm, &(dtype->lanes))) return false;
......@@ -30,14 +30,14 @@ struct Handler<DLDataType> {
}
};
template<>
template <>
struct Handler<DLContext> {
inline static void Write(Stream *strm, const DLContext& ctx) {
inline static void Write(Stream *strm, const DLContext &ctx) {
int32_t device_type = static_cast<int32_t>(ctx.device_type);
Handler<int32_t>::Write(strm, device_type);
Handler<int32_t>::Write(strm, ctx.device_id);
}
inline static bool Read(Stream *strm, DLContext* ctx) {
inline static bool Read(Stream *strm, DLContext *ctx) {
int32_t device_type = 0;
if (!Handler<int32_t>::Read(strm, &(device_type))) return false;
ctx->device_type = static_cast<DLDeviceType>(device_type);
......
/*!
* Copyright (c) 2017 by Contributors
* \file dgl/runtime/serializer.h
* \brief Serializer extension to support DGL data types
* Include this file to enable serialization of DLDataType, DLContext
*/
#ifndef DGL_RUNTIME_SMART_PTR_SERIALIZER_H_
#define DGL_RUNTIME_SMART_PTR_SERIALIZER_H_
#include <dmlc/io.h>
#include <dmlc/serializer.h>
namespace dmlc {
namespace serializer {
//! \cond Doxygen_Suppress
template <typename T>
struct Handler<std::shared_ptr<T>> {
inline static void Write(Stream *strm, const std::shared_ptr<T> &data) {
Handler<T>::Write(strm, *data.get());
}
inline static bool Read(Stream *strm, std::shared_ptr<T> *data) {
// When read, the default initialization behavior of shared_ptr is
// shared_ptr<T>(), which is holding a nullptr. Here we need to manually
// reset to a real object for further loading
if (!(*data)) {
data->reset(new T());
}
return Handler<T>::Read(strm, data->get());
}
};
template <typename T>
struct Handler<std::unique_ptr<T>> {
inline static void Write(Stream *strm, const std::unique_ptr<T> &data) {
Handler<T>::Write(strm, *data.get());
}
inline static bool Read(Stream *strm, std::unique_ptr<T> *data) {
// When read, the default initialization behavior of unique_ptr is
// unique_ptr<T>(), which is holding a nullptr. Here we need to manually
// reset to a real object for further loading
if (!(*data)) {
data->reset(new T());
}
return Handler<T>::Read(strm, data->get());
}
};
} // namespace serializer
} // namespace dmlc
#endif // DGL_RUNTIME_SMART_PTR_SERIALIZER_H_
#include <dgl/runtime/serializer.h>
#include <dgl/runtime/smart_ptr_serializer.h>
#include <dmlc/io.h>
#include <dmlc/logging.h>
#include <dmlc/memory_io.h>
#include <dmlc/parameter.h>
#include <gtest/gtest.h>
#include <cstring>
#include <iostream>
#include <sstream>
#include <unordered_map>
using namespace std;
class MyClass {
public:
MyClass() {}
MyClass(std::string data) : data_(data) {}
inline void Save(dmlc::Stream *strm) const { strm->Write(this->data_); }
inline bool Load(dmlc::Stream *strm) { return strm->Read(&data_); }
inline bool operator==(const MyClass &other) const {
return data_ == other.data_;
}
public:
std::string data_;
};
// need to declare the traits property of my class to dmlc
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, MyClass, true);
}
template <typename T>
class SmartPtrTest : public ::testing::Test {
public:
typedef T SmartPtr;
};
using SmartPtrTypes =
::testing::Types<std::shared_ptr<MyClass>, std::unique_ptr<MyClass>>;
TYPED_TEST_SUITE(SmartPtrTest, SmartPtrTypes);
TYPED_TEST(SmartPtrTest, Obj_Test) {
std::string blob;
dmlc::MemoryStringStream fs(&blob);
using SmartPtr = typename TestFixture::SmartPtr;
auto myc = SmartPtr(new MyClass("1111"));
{ static_cast<dmlc::Stream *>(&fs)->Write(myc); }
fs.Seek(0);
auto copy_data = SmartPtr(new MyClass());
CHECK(static_cast<dmlc::Stream *>(&fs)->Read(&copy_data));
EXPECT_EQ(myc->data_, copy_data->data_);
}
TYPED_TEST(SmartPtrTest, Vector_Test1) {
std::string blob;
dmlc::MemoryStringStream fs(&blob);
using SmartPtr = typename TestFixture::SmartPtr;
typedef std::pair<std::string, SmartPtr> Pair;
auto my1 = SmartPtr(new MyClass("@A@"));
auto my2 = SmartPtr(new MyClass("2222"));
std::vector<Pair> myclasses;
myclasses.emplace_back("a", SmartPtr(new MyClass("@A@B")));
myclasses.emplace_back("b", SmartPtr(new MyClass("2222")));
static_cast<dmlc::Stream *>(&fs)->Write<std::vector<Pair>>(myclasses);
dmlc::MemoryStringStream ofs(&blob);
std::vector<Pair> copy_myclasses;
static_cast<dmlc::Stream *>(&ofs)->Read<std::vector<Pair>>(&copy_myclasses);
EXPECT_TRUE(std::equal(myclasses.begin(), myclasses.end(),
copy_myclasses.begin(),
[](const Pair &left, const Pair &right) {
return (left.second->data_ == right.second->data_) &&
(left.first == right.first);
}));
}
TYPED_TEST(SmartPtrTest, Vector_Test2) {
std::string blob;
dmlc::MemoryStringStream fs(&blob);
using SmartPtr = typename TestFixture::SmartPtr;
auto my1 = SmartPtr(new MyClass("@A@"));
auto my2 = SmartPtr(new MyClass("2222"));
std::vector<SmartPtr> myclasses;
myclasses.emplace_back(new MyClass("@A@"));
myclasses.emplace_back(new MyClass("2222"));
static_cast<dmlc::Stream *>(&fs)->Write<std::vector<SmartPtr>>(myclasses);
dmlc::MemoryStringStream ofs(&blob);
std::vector<SmartPtr> copy_myclasses;
static_cast<dmlc::Stream *>(&ofs)->Read<std::vector<SmartPtr>>(
&copy_myclasses);
EXPECT_TRUE(std::equal(myclasses.begin(), myclasses.end(),
copy_myclasses.begin(),
[](const SmartPtr &left, const SmartPtr &right) {
return left->data_ == right->data_;
}));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment