test_serialize.cc 2.8 KB
Newer Older
1
#include <dgl/immutable_graph.h>
2
3
4
5
6
#include <dmlc/memory_io.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
7
8
#include "../../src/graph/graph_serializer.h"
#include "../../src/graph/heterograph.h"
9
10
11
12
13
14
15
#include "../../src/graph/unit_graph.h"
#include "./common.h"

using namespace dgl;
using namespace dgl::aten;
using namespace dmlc;

16
TEST(Serialize, DISABLED_UnitGraph) {
17
18
19
20
21
22
23
24
25
26
27
  aten::CSRMatrix csr_matrix;
  auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
  auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
  auto mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst);
  UnitGraph* ug = dynamic_cast<UnitGraph*>(mg.get());
  std::string blob;
  dmlc::MemoryStringStream ifs(&blob);

  static_cast<dmlc::Stream*>(&ifs)->Write<UnitGraph>(*ug);

  dmlc::MemoryStringStream ofs(&blob);
28
  UnitGraph* ug2 = Serializer::EmptyUnitGraph();
29
  static_cast<dmlc::Stream*>(&ofs)->Read(ug2);
30
31
  EXPECT_EQ(ug2->NumVertices(0), 9);
  EXPECT_EQ(ug2->NumVertices(1), 8);
32
  EXPECT_EQ(ug2->NumEdges(0), 4);
33
34
  EXPECT_EQ(ug2->FindEdge(0, 1).first, 2);
  EXPECT_EQ(ug2->FindEdge(0, 1).second, 6);
35
  delete ug2;
36
37
}

38
TEST(Serialize, DISABLED_ImmutableGraph) {
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
  auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
  auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
  auto gptr = ImmutableGraph::CreateFromCOO(10, src, dst);
  ImmutableGraph* rptr = gptr.get();

  std::string blob;
  dmlc::MemoryStringStream ifs(&blob);

  static_cast<dmlc::Stream*>(&ifs)->Write(*rptr);

  dmlc::MemoryStringStream ofs(&blob);
  ImmutableGraph* rptr_read = new ImmutableGraph(static_cast<COOPtr>(nullptr));
  static_cast<dmlc::Stream*>(&ofs)->Read(rptr_read);
  EXPECT_EQ(rptr_read->NumEdges(), 4);
  EXPECT_EQ(rptr_read->NumVertices(), 10);
  EXPECT_EQ(rptr_read->FindEdge(2).first, 5);
  EXPECT_EQ(rptr_read->FindEdge(2).second, 2);
56
  delete rptr_read;
57
58
}

59
TEST(Serialize, DISABLED_HeteroGraph) {
60
61
62
  auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
  auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
  auto mg1 = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst);
Jinjing Zhou's avatar
Jinjing Zhou committed
63
64
  src = VecToIdArray<int64_t>({6, 2, 5, 1, 8});
  dst = VecToIdArray<int64_t>({5, 2, 4, 8, 0});
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
  auto mg2 = dgl::UnitGraph::CreateFromCOO(1, 9, 9, src, dst);
  std::vector<HeteroGraphPtr> relgraphs;
  relgraphs.push_back(mg1);
  relgraphs.push_back(mg2);
  src = VecToIdArray<int64_t>({0, 0});
  dst = VecToIdArray<int64_t>({1, 0});
  auto meta_gptr = ImmutableGraph::CreateFromCOO(2, src, dst);
  HeteroGraph* hrptr = new HeteroGraph(meta_gptr, relgraphs);

  std::string blob;
  dmlc::MemoryStringStream ifs(&blob);
  static_cast<dmlc::Stream*>(&ifs)->Write(*hrptr);

  dmlc::MemoryStringStream ofs(&blob);
  HeteroGraph* gptr = dgl::Serializer::EmptyHeteroGraph();
  static_cast<dmlc::Stream*>(&ofs)->Read(gptr);
  EXPECT_EQ(gptr->NumVertices(0), 9);
  EXPECT_EQ(gptr->NumVertices(1), 8);
83
84
85
  delete hrptr;
  delete gptr;
}