"docs/vscode:/vscode.git/clone" did not exist on "982f20284ba55f20f4c65d0968fe2b4da3e50dd9"
Unverified Commit 23893bb0 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Data format] Serialization for Immutable Graph and HeteroGraph (#1254)



* graph format

* fix lint

* lint

* fix

* unit test

* lint

* add magic num

* move serialize out of struct

* lint

* serialize

* trigger CI

* fix lint

* lint
Co-authored-by: default avatarzhoujinjing09 <zhoujinjing09@users.noreply.github.com>
parent ffe58983
......@@ -958,6 +958,12 @@ class ImmutableGraph: public GraphInterface {
*/
ImmutableGraphPtr Reverse() const;
/*! \return Load HeteroGraph from stream, using CSRMatrix*/
bool Load(dmlc::Stream* fs);
/*! \return Save HeteroGraph to stream, using CSRMatrix */
void Save(dmlc::Stream* fs) const;
void SortCSR() {
GetInCSR()->SortCSR();
GetOutCSR()->SortCSR();
......@@ -1028,4 +1034,8 @@ CSR::CSR(int64_t num_vertices, int64_t num_edges,
} // namespace dgl
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::ImmutableGraph, true);
} // namespace dmlc
#endif // DGL_IMMUTABLE_GRAPH_H_
/*!
* Copyright (c) 2018 by Contributors
* \file graph/graph_serializer.cc
* \brief DGL serializer APIs
*/
#pragma once
#include <dgl/immutable_graph.h>
#include "heterograph.h"
#include "unit_graph.h"
namespace dgl {
class Serializer {
public:
static HeteroGraph* EmptyHeteroGraph() { return new HeteroGraph(); }
static ImmutableGraph* EmptyImmutableGraph() {
return new ImmutableGraph(static_cast<COOPtr>(nullptr));
}
static UnitGraph* EmptyUnitGraph() {
return UnitGraph::EmptyGraph();
}
};
} // namespace dgl
......@@ -4,14 +4,18 @@
* \brief Heterograph implementation
*/
#include "./heterograph.h"
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/immutable_graph.h>
#include <vector>
#include <tuple>
#include <utility>
#include "../c_api_common.h"
#include "./unit_graph.h"
#include "graph_serializer.h"
// TODO(BarclayII): currently CompactGraphs depend on IdHashMap implementation which
// only works on CPU. Should fix later to make it device agnostic.
#include "../array/cpu/array_utils.h"
......@@ -21,6 +25,8 @@ using namespace dgl::runtime;
namespace dgl {
namespace {
using dgl::ImmutableGraph;
HeteroSubgraph EdgeSubgraphPreserveNodes(
const HeteroGraph* hg, const std::vector<IdArray>& eids) {
CHECK_EQ(eids.size(), hg->NumEdgeTypes())
......@@ -494,6 +500,39 @@ CompactGraphs(const std::vector<HeteroGraphPtr> &graphs) {
return result;
}
constexpr uint64_t kDGLSerialize_HeteroGraph = 0xDD589FBE35224ABF;
bool HeteroGraph::Load(dmlc::Stream* fs) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_HeteroGraph) << "Invalid HeteroGraph Data";
auto meta_grptr = new ImmutableGraph(static_cast<COOPtr>(nullptr));
CHECK(fs->Read(meta_grptr)) << "Invalid Immutable Graph Data";
uint64_t num_relation_graphs;
CHECK(fs->Read(&num_relation_graphs)) << "Invalid num of relation graphs";
std::vector<HeteroGraphPtr> relgraphs;
for (size_t i = 0; i < num_relation_graphs; ++i) {
UnitGraph* ugptr = Serializer::EmptyUnitGraph();
CHECK(fs->Read(ugptr)) << "Invalid UnitGraph Data";
relgraphs.emplace_back(dynamic_cast<BaseHeteroGraph*>(ugptr));
}
HeteroGraph* hgptr = new HeteroGraph(GraphPtr(meta_grptr), relgraphs);
*this = *hgptr;
return true;
}
void HeteroGraph::Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_HeteroGraph);
auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
ImmutableGraph* meta_rptr = meta_graph_ptr.get();
fs->Write(*meta_rptr);
fs->Write(static_cast<uint64_t>(relation_graphs_.size()));
for (auto hptr : relation_graphs_) {
auto rptr = dynamic_cast<UnitGraph*>(hptr.get());
fs->Write(*rptr);
}
}
///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
......
......@@ -170,7 +170,20 @@ class HeteroGraph : public BaseHeteroGraph {
FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const override;
/*! \return Load HeteroGraph from stream, using CSRMatrix*/
bool Load(dmlc::Stream* fs);
/*! \return Save HeteroGraph to stream, using CSRMatrix */
void Save(dmlc::Stream* fs) const;
private:
// To create empty class
friend class Serializer;
// Empty Constructor, only for serializer
HeteroGraph() : BaseHeteroGraph(static_cast<GraphPtr>(nullptr)) {}
/*! \brief A map from edge type to unit graph */
std::vector<HeteroGraphPtr> relation_graphs_;
......@@ -183,4 +196,10 @@ class HeteroGraph : public BaseHeteroGraph {
} // namespace dgl
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::HeteroGraph, true);
} // namespace dmlc
#endif // DGL_GRAPH_HETEROGRAPH_H_
......@@ -6,6 +6,8 @@
#include <dgl/packed_func_ext.h>
#include <dgl/immutable_graph.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <string.h>
#include <bitset>
#include <numeric>
......@@ -634,6 +636,28 @@ ImmutableGraphPtr ImmutableGraph::Reverse() const {
}
}
constexpr uint64_t kDGLSerialize_ImGraph = 0xDD3c5FFE20046ABF;
/*! \return Load HeteroGraph from stream, using OutCSR Matrix*/
bool ImmutableGraph::Load(dmlc::Stream *fs) {
uint64_t magicNum;
aten::CSRMatrix out_csr_matrix;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_ImGraph) << "Invalid ImmutableGraph Data";
CHECK(fs->Read(&out_csr_matrix)) << "Invalid csr matrix";
CSRPtr csr(new CSR(out_csr_matrix.indptr, out_csr_matrix.indices,
out_csr_matrix.data));
auto g = new ImmutableGraph(nullptr, csr);
*this = *g;
return true;
}
/*! \return Save HeteroGraph to stream, using OutCSR Matrix */
void ImmutableGraph::Save(dmlc::Stream *fs) const {
fs->Write(kDGLSerialize_ImGraph);
fs->Write(GetOutCSR()->ToCSRMatrix());
}
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
......
......@@ -14,6 +14,9 @@
namespace dgl {
namespace {
using namespace dgl::aten;
// create metagraph of one node type
inline GraphPtr CreateUnitGraphMetaGraph1() {
// a self-loop edge 0->0
......@@ -1152,8 +1155,17 @@ SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const {
return SparseFormat::COO;
}
UnitGraph* UnitGraph::EmptyGraph() {
auto src = NewIdArray(0);
auto dst = NewIdArray(0);
auto mg = CreateUnitGraphMetaGraph(1);
COOPtr coo(new COO(mg, 0, 0, src, dst));
return new UnitGraph(mg, nullptr, nullptr, coo);
}
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;
// Using OurCSR
bool UnitGraph::Load(dmlc::Stream* fs) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
......@@ -1164,27 +1176,25 @@ bool UnitGraph::Load(dmlc::Stream* fs) {
CHECK(fs->Read(&num_dst)) << "Invalid num_dst";
aten::CSRMatrix csr_matrix;
CHECK(fs->Read(&csr_matrix)) << "Invalid csr_matrix";
SparseFormat restrict_format;
CHECK(fs->Read(&restrict_format)) << "Invalid restrict_format";
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr csr(new CSR(mg, num_src, num_dst, csr_matrix.indptr, csr_matrix.indices, csr_matrix.data));
*this = UnitGraph(mg, csr, nullptr, nullptr);
CSRPtr csr(new CSR(mg, num_src, num_dst, csr_matrix.indptr,
csr_matrix.indices, csr_matrix.data));
*this = UnitGraph(mg, nullptr, csr, nullptr);
return true;
}
// Using Out CSR
void UnitGraph::Save(dmlc::Stream* fs) const {
// Following CreateFromCSR signature
aten::CSRMatrix csr_matrix = GetInCSRMatrix();
aten::CSRMatrix csr_matrix = GetOutCSRMatrix();
uint64_t num_vtypes = NumVertexTypes();
uint64_t num_src = NumVertices(SrcType());
uint64_t num_dst = NumVertices(DstType());
SparseFormat restrict_format = restrict_format_;
fs->Write(kDGLSerialize_UnitGraphMagic);
fs->Write(num_vtypes);
fs->Write(num_src);
fs->Write(num_dst);
fs->Write(csr_matrix);
fs->Write(restrict_format);
}
} // namespace dgl
......@@ -186,6 +186,8 @@ class UnitGraph : public BaseHeteroGraph {
void Save(dmlc::Stream* fs) const;
private:
friend class Serializer;
/*!
* \brief constructor
* \param metagraph metagraph
......@@ -219,6 +221,9 @@ class UnitGraph : public BaseHeteroGraph {
/*! \return Whether the graph is hypersparse */
bool IsHypersparse() const;
// Empty Graph for Serializer Usgae
static UnitGraph* EmptyGraph();
// Graph stored in different format. We use an on-demand strategy: the format is
// only materialized if the operation that suitable for it is invoked.
/*! \brief CSR graph that stores reverse edges */
......
#include <dgl/immutable_graph.h>
#include <dmlc/memory_io.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <iostream>
#include <vector>
#include "../../src/graph/graph_serializer.h"
#include "../../src/graph/heterograph.h"
#include "../../src/graph/unit_graph.h"
#include "./common.h"
......@@ -22,13 +25,58 @@ TEST(Serialize, UnitGraph) {
static_cast<dmlc::Stream*>(&ifs)->Write<UnitGraph>(*ug);
dmlc::MemoryStringStream ofs(&blob);
src = NewIdArray(0);
dst = NewIdArray(0);
auto mg2 = dgl::UnitGraph::CreateFromCOO(
1, 0, 0, src, dst); // Any way to construct Empty UnitGraph?
UnitGraph* ug2 = dynamic_cast<UnitGraph*>(mg2.get());
UnitGraph* ug2 = Serializer::EmptyUnitGraph();
static_cast<dmlc::Stream*>(&ofs)->Read(ug2);
EXPECT_EQ(ug2->NumVertices(0), 8);
EXPECT_EQ(ug2->NumVertices(1), 9);
EXPECT_EQ(ug2->NumVertices(0), 9);
EXPECT_EQ(ug2->NumVertices(1), 8);
EXPECT_EQ(ug2->NumEdges(0), 4);
EXPECT_EQ(ug2->FindEdge(0, 1).first, 2);
EXPECT_EQ(ug2->FindEdge(0, 1).second, 6);
}
TEST(Serialize, ImmutableGraph) {
auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto gptr = ImmutableGraph::CreateFromCOO(10, src, dst);
ImmutableGraph* rptr = gptr.get();
std::string blob;
dmlc::MemoryStringStream ifs(&blob);
static_cast<dmlc::Stream*>(&ifs)->Write(*rptr);
dmlc::MemoryStringStream ofs(&blob);
ImmutableGraph* rptr_read = new ImmutableGraph(static_cast<COOPtr>(nullptr));
static_cast<dmlc::Stream*>(&ofs)->Read(rptr_read);
EXPECT_EQ(rptr_read->NumEdges(), 4);
EXPECT_EQ(rptr_read->NumVertices(), 10);
EXPECT_EQ(rptr_read->FindEdge(2).first, 5);
EXPECT_EQ(rptr_read->FindEdge(2).second, 2);
}
TEST(Serialize, HeteroGraph) {
auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto mg1 = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst);
src = VecToIdArray<int64_t>({6, 2, 5, 1, 9});
dst = VecToIdArray<int64_t>({5, 2, 4, 9, 0});
auto mg2 = dgl::UnitGraph::CreateFromCOO(1, 9, 9, src, dst);
std::vector<HeteroGraphPtr> relgraphs;
relgraphs.push_back(mg1);
relgraphs.push_back(mg2);
src = VecToIdArray<int64_t>({0, 0});
dst = VecToIdArray<int64_t>({1, 0});
auto meta_gptr = ImmutableGraph::CreateFromCOO(2, src, dst);
HeteroGraph* hrptr = new HeteroGraph(meta_gptr, relgraphs);
std::string blob;
dmlc::MemoryStringStream ifs(&blob);
static_cast<dmlc::Stream*>(&ifs)->Write(*hrptr);
dmlc::MemoryStringStream ofs(&blob);
HeteroGraph* gptr = dgl::Serializer::EmptyHeteroGraph();
static_cast<dmlc::Stream*>(&ofs)->Read(gptr);
EXPECT_EQ(gptr->NumVertices(0), 9);
EXPECT_EQ(gptr->NumVertices(1), 8);
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment