// test_serialize.cc: round-trip serialization tests for UnitGraph,
// ImmutableGraph and HeteroGraph.
#include <dgl/graph_serializer.h>
#include <dgl/immutable_graph.h>
#include <dmlc/memory_io.h>
#include <gtest/gtest.h>

#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "../../src/graph/heterograph.h"
#include "../../src/graph/unit_graph.h"
#include "./common.h"

using namespace dgl;
using namespace dgl::aten;
using namespace dmlc;

19
TEST(Serialize, UnitGraph_COO) {
20
21
22
  aten::CSRMatrix csr_matrix;
  auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
  auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
23
  auto mg = std::dynamic_pointer_cast<UnitGraph>(
24
      dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, COO_CODE));
25

26
27
28
  std::string blob;
  dmlc::MemoryStringStream ifs(&blob);

29
  static_cast<dmlc::Stream *>(&ifs)->Write(mg);
30
31

  dmlc::MemoryStringStream ofs(&blob);
32
33
  auto ug2 = Serializer::make_shared<UnitGraph>();
  static_cast<dmlc::Stream *>(&ofs)->Read(&ug2);
34
35
  EXPECT_EQ(ug2->NumVertices(0), 9);
  EXPECT_EQ(ug2->NumVertices(1), 8);
36
  EXPECT_EQ(ug2->NumEdges(0), 4);
37
38
39
40
  EXPECT_EQ(ug2->FindEdge(0, 1).first, 2);
  EXPECT_EQ(ug2->FindEdge(0, 1).second, 6);
}

41
42
TEST(Serialize, UnitGraph_CSR) {
  aten::CSRMatrix csr_matrix;
43
44
  auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
  auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
45
46
47
  auto coo_g = std::dynamic_pointer_cast<UnitGraph>(
      dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst));
  auto csr_g =
48
      std::dynamic_pointer_cast<UnitGraph>(coo_g->GetGraphInFormat(CSR_CODE));
49
50
51
52

  std::string blob;
  dmlc::MemoryStringStream ifs(&blob);

53
  static_cast<dmlc::Stream *>(&ifs)->Write(csr_g);
54

55
56
57
58
59
60
61
62
63
64
  dmlc::MemoryStringStream ofs(&blob);
  auto ug2 = Serializer::make_shared<UnitGraph>();
  static_cast<dmlc::Stream *>(&ofs)->Read(&ug2);
  // Query operation is not supported on CSR, how to check it?
}

// Round-trips an ImmutableGraph built from COO arrays through an in-memory
// dmlc stream and checks counts and one edge's endpoints.
TEST(Serialize, ImmutableGraph) {
  // Four edges over 10 vertices: (1,1), (2,6), (5,2), (3,6).
  auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
  auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
  auto gptr = ImmutableGraph::CreateFromCOO(10, src, dst);

  std::string blob;
  // Serialize the graph into `blob`.
  dmlc::MemoryStringStream write_stream(&blob);
  static_cast<dmlc::Stream *>(&write_stream)->Write(gptr);

  // Deserialize through a second stream over the same buffer.
  dmlc::MemoryStringStream read_stream(&blob);
  auto rptr_read = dgl::Serializer::make_shared<ImmutableGraph>();
  static_cast<dmlc::Stream *>(&read_stream)->Read(&rptr_read);

  EXPECT_EQ(rptr_read->NumEdges(), 4);
  EXPECT_EQ(rptr_read->NumVertices(), 10);
  // Edge id 2 was created as (5, 2).
  EXPECT_EQ(rptr_read->FindEdge(2).first, 5);
  EXPECT_EQ(rptr_read->FindEdge(2).second, 2);
}

79
TEST(Serialize, HeteroGraph) {
80
81
82
  auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
  auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
  auto mg1 = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst);
Jinjing Zhou's avatar
Jinjing Zhou committed
83
84
  src = VecToIdArray<int64_t>({6, 2, 5, 1, 8});
  dst = VecToIdArray<int64_t>({5, 2, 4, 8, 0});
85
86
87
88
89
90
  auto mg2 = dgl::UnitGraph::CreateFromCOO(1, 9, 9, src, dst);
  std::vector<HeteroGraphPtr> relgraphs;
  relgraphs.push_back(mg1);
  relgraphs.push_back(mg2);
  src = VecToIdArray<int64_t>({0, 0});
  dst = VecToIdArray<int64_t>({1, 0});
91
92
  auto meta_gptr = ImmutableGraph::CreateFromCOO(3, src, dst);
  auto hrptr = std::make_shared<HeteroGraph>(meta_gptr, relgraphs);
93
94
95

  std::string blob;
  dmlc::MemoryStringStream ifs(&blob);
96
  static_cast<dmlc::Stream *>(&ifs)->Write(hrptr);
97
98

  dmlc::MemoryStringStream ofs(&blob);
99
100
  auto gptr = dgl::Serializer::make_shared<HeteroGraph>();
  static_cast<dmlc::Stream *>(&ofs)->Read(&gptr);
101
102
  EXPECT_EQ(gptr->NumVertices(0), 9);
  EXPECT_EQ(gptr->NumVertices(1), 8);
103
}