test_smart_ptr_serialize.cc 3.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include <dgl/runtime/serializer.h>
#include <dgl/runtime/smart_ptr_serializer.h>
#include <dmlc/io.h>
#include <dmlc/logging.h>
#include <dmlc/memory_io.h>
#include <dmlc/parameter.h>
#include <gtest/gtest.h>
#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

using namespace std;

// Small serializable payload used by the smart pointer tests below.
// It wraps a single string and exposes Save/Load members that write the
// string to, and read it back from, a dmlc::Stream.
class MyClass {
 public:
  MyClass() = default;
  // Takes the string by value and moves it into place, so callers passing a
  // temporary pay for no extra copy.
  MyClass(std::string data) : data_(std::move(data)) {}

  // Writes data_ to the given stream.
  void Save(dmlc::Stream *strm) const { strm->Write(this->data_); }

  // Reads data_ from the given stream; returns whatever the stream's Read
  // call reports.
  bool Load(dmlc::Stream *strm) { return strm->Read(&data_); }

  // Two instances compare equal when their strings are equal.
  bool operator==(const MyClass &other) const { return data_ == other.data_; }

  // Public on purpose: the tests compare this member directly.
  std::string data_;
};
// dmlc has to be told about MyClass explicitly: this declares the
// has_saveload trait as true for it.
// NOTE(review): presumably this is what makes dmlc::Stream::Write/Read go
// through MyClass::Save/Load -- confirm against dmlc's serializer headers.
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, MyClass, true);
}

// Typed-test fixture. The template argument is the smart pointer type under
// test; it is re-exported as SmartPtr so test bodies can name it through
// TestFixture::SmartPtr.
template <typename T>
class SmartPtrTest : public ::testing::Test {
 public:
  using SmartPtr = T;
};

// Smart pointer types the typed tests are written against: a shared_ptr and
// a unique_ptr, each holding a MyClass.
using SmartPtrTypes =
    ::testing::Types<std::shared_ptr<MyClass>, std::unique_ptr<MyClass>>;
// Associate the type list with the SmartPtrTest fixture.
TYPED_TEST_SUITE(SmartPtrTest, SmartPtrTypes);

// Writes one smart pointer to an in-memory stream, reads it back into a
// second smart pointer and checks that both point at equal strings.
TYPED_TEST(SmartPtrTest, Obj_Test) {
  using SmartPtr = typename TestFixture::SmartPtr;

  std::string blob;
  dmlc::MemoryStringStream fs(&blob);
  // Go through the dmlc::Stream base so the templated Write/Read are used.
  dmlc::Stream *strm = &fs;

  auto myc = SmartPtr(new MyClass("1111"));
  strm->Write(myc);

  // Seek back to offset 0 before reading from the same stream object.
  fs.Seek(0);
  auto copy_data = SmartPtr(new MyClass());
  // Use gtest assertions (not dmlc's CHECK) so a failed read is reported as
  // an ordinary test failure, and stop before dereferencing a null pointer.
  ASSERT_TRUE(strm->Read(&copy_data));
  ASSERT_TRUE(copy_data != nullptr);

  EXPECT_EQ(myc->data_, copy_data->data_);
}

// Round-trips a vector of (string, smart pointer) pairs through an in-memory
// stream and checks that keys and pointed-to strings come back unchanged.
TYPED_TEST(SmartPtrTest, Vector_Test1) {
  using SmartPtr = typename TestFixture::SmartPtr;
  using Pair = std::pair<std::string, SmartPtr>;

  std::string blob;
  dmlc::MemoryStringStream fs(&blob);

  std::vector<Pair> myclasses;
  myclasses.emplace_back("a", SmartPtr(new MyClass("@A@B")));
  myclasses.emplace_back("b", SmartPtr(new MyClass("2222")));
  static_cast<dmlc::Stream *>(&fs)->Write<std::vector<Pair>>(myclasses);

  // A second stream over the same buffer is used for reading.
  dmlc::MemoryStringStream ofs(&blob);
  dmlc::Stream *in = &ofs;
  std::vector<Pair> copy_myclasses;
  const bool read_ok = in->Read<std::vector<Pair>>(&copy_myclasses);
  ASSERT_TRUE(read_ok);

  // Check the length first: comparing element-wise against a shorter vector
  // would read past its end.
  ASSERT_EQ(myclasses.size(), copy_myclasses.size());
  for (std::size_t i = 0; i < myclasses.size(); ++i) {
    ASSERT_TRUE(copy_myclasses[i].second != nullptr);
    EXPECT_EQ(myclasses[i].first, copy_myclasses[i].first);
    EXPECT_EQ(myclasses[i].second->data_, copy_myclasses[i].second->data_);
  }
}

// Round-trips a vector of smart pointers through an in-memory stream and
// checks that every pointed-to string comes back unchanged.
TYPED_TEST(SmartPtrTest, Vector_Test2) {
  using SmartPtr = typename TestFixture::SmartPtr;

  std::string blob;
  dmlc::MemoryStringStream fs(&blob);

  std::vector<SmartPtr> myclasses;
  // Wrap each allocation in its owner before handing it to the vector, so the
  // object is not leaked if growing the vector throws.
  myclasses.push_back(SmartPtr(new MyClass("@A@")));
  myclasses.push_back(SmartPtr(new MyClass("2222")));
  static_cast<dmlc::Stream *>(&fs)->Write<std::vector<SmartPtr>>(myclasses);

  // A second stream over the same buffer is used for reading.
  dmlc::MemoryStringStream ofs(&blob);
  dmlc::Stream *in = &ofs;
  std::vector<SmartPtr> copy_myclasses;
  const bool read_ok = in->Read<std::vector<SmartPtr>>(&copy_myclasses);
  ASSERT_TRUE(read_ok);

  // Check the length first: comparing element-wise against a shorter vector
  // would read past its end.
  ASSERT_EQ(myclasses.size(), copy_myclasses.size());
  for (std::size_t i = 0; i < myclasses.size(); ++i) {
    ASSERT_TRUE(copy_myclasses[i] != nullptr);
    EXPECT_EQ(myclasses[i]->data_, copy_myclasses[i]->data_);
  }
}