"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3139d39fa73baf1fcddb4d9feea58b5f9cfd86e4"
Unverified Commit 23893bb0 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Data format] Serialization for Immutable Graph and HeteroGraph (#1254)



* graph format

* fix lint

* lint

* fix

* unit test

* lint

* add magic num

* move serialize out of struct

* lint

* serialize

* trigger CI

* fix lint

* lint
Co-authored-by: default avatarzhoujinjing09 <zhoujinjing09@users.noreply.github.com>
parent ffe58983
...@@ -958,6 +958,12 @@ class ImmutableGraph: public GraphInterface { ...@@ -958,6 +958,12 @@ class ImmutableGraph: public GraphInterface {
*/ */
ImmutableGraphPtr Reverse() const; ImmutableGraphPtr Reverse() const;
/*! \return Load HeteroGraph from stream, using CSRMatrix*/
bool Load(dmlc::Stream* fs);
/*! \return Save HeteroGraph to stream, using CSRMatrix */
void Save(dmlc::Stream* fs) const;
void SortCSR() { void SortCSR() {
GetInCSR()->SortCSR(); GetInCSR()->SortCSR();
GetOutCSR()->SortCSR(); GetOutCSR()->SortCSR();
...@@ -1028,4 +1034,8 @@ CSR::CSR(int64_t num_vertices, int64_t num_edges, ...@@ -1028,4 +1034,8 @@ CSR::CSR(int64_t num_vertices, int64_t num_edges,
} // namespace dgl } // namespace dgl
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::ImmutableGraph, true);
} // namespace dmlc
#endif // DGL_IMMUTABLE_GRAPH_H_ #endif // DGL_IMMUTABLE_GRAPH_H_
/*!
* Copyright (c) 2018 by Contributors
* \file graph/graph_serializer.cc
* \brief DGL serializer APIs
*/
#pragma once
#include <dgl/immutable_graph.h>
#include "heterograph.h"
#include "unit_graph.h"
namespace dgl {
class Serializer {
public:
static HeteroGraph* EmptyHeteroGraph() { return new HeteroGraph(); }
static ImmutableGraph* EmptyImmutableGraph() {
return new ImmutableGraph(static_cast<COOPtr>(nullptr));
}
static UnitGraph* EmptyUnitGraph() {
return UnitGraph::EmptyGraph();
}
};
} // namespace dgl
...@@ -4,14 +4,18 @@ ...@@ -4,14 +4,18 @@
* \brief Heterograph implementation * \brief Heterograph implementation
*/ */
#include "./heterograph.h" #include "./heterograph.h"
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include <dgl/immutable_graph.h>
#include <vector> #include <vector>
#include <tuple> #include <tuple>
#include <utility> #include <utility>
#include "../c_api_common.h" #include "../c_api_common.h"
#include "./unit_graph.h" #include "./unit_graph.h"
#include "graph_serializer.h"
// TODO(BarclayII): currently CompactGraphs depend on IdHashMap implementation which // TODO(BarclayII): currently CompactGraphs depend on IdHashMap implementation which
// only works on CPU. Should fix later to make it device agnostic. // only works on CPU. Should fix later to make it device agnostic.
#include "../array/cpu/array_utils.h" #include "../array/cpu/array_utils.h"
...@@ -21,6 +25,8 @@ using namespace dgl::runtime; ...@@ -21,6 +25,8 @@ using namespace dgl::runtime;
namespace dgl { namespace dgl {
namespace { namespace {
using dgl::ImmutableGraph;
HeteroSubgraph EdgeSubgraphPreserveNodes( HeteroSubgraph EdgeSubgraphPreserveNodes(
const HeteroGraph* hg, const std::vector<IdArray>& eids) { const HeteroGraph* hg, const std::vector<IdArray>& eids) {
CHECK_EQ(eids.size(), hg->NumEdgeTypes()) CHECK_EQ(eids.size(), hg->NumEdgeTypes())
...@@ -494,6 +500,39 @@ CompactGraphs(const std::vector<HeteroGraphPtr> &graphs) { ...@@ -494,6 +500,39 @@ CompactGraphs(const std::vector<HeteroGraphPtr> &graphs) {
return result; return result;
} }
constexpr uint64_t kDGLSerialize_HeteroGraph = 0xDD589FBE35224ABF;
bool HeteroGraph::Load(dmlc::Stream* fs) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_HeteroGraph) << "Invalid HeteroGraph Data";
auto meta_grptr = new ImmutableGraph(static_cast<COOPtr>(nullptr));
CHECK(fs->Read(meta_grptr)) << "Invalid Immutable Graph Data";
uint64_t num_relation_graphs;
CHECK(fs->Read(&num_relation_graphs)) << "Invalid num of relation graphs";
std::vector<HeteroGraphPtr> relgraphs;
for (size_t i = 0; i < num_relation_graphs; ++i) {
UnitGraph* ugptr = Serializer::EmptyUnitGraph();
CHECK(fs->Read(ugptr)) << "Invalid UnitGraph Data";
relgraphs.emplace_back(dynamic_cast<BaseHeteroGraph*>(ugptr));
}
HeteroGraph* hgptr = new HeteroGraph(GraphPtr(meta_grptr), relgraphs);
*this = *hgptr;
return true;
}
void HeteroGraph::Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_HeteroGraph);
auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
ImmutableGraph* meta_rptr = meta_graph_ptr.get();
fs->Write(*meta_rptr);
fs->Write(static_cast<uint64_t>(relation_graphs_.size()));
for (auto hptr : relation_graphs_) {
auto rptr = dynamic_cast<UnitGraph*>(hptr.get());
fs->Write(*rptr);
}
}
///////////////////////// C APIs ///////////////////////// ///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
......
...@@ -170,7 +170,20 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -170,7 +170,20 @@ class HeteroGraph : public BaseHeteroGraph {
FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const override; FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const override;
/*! \return Load HeteroGraph from stream, using CSRMatrix*/
bool Load(dmlc::Stream* fs);
/*! \return Save HeteroGraph to stream, using CSRMatrix */
void Save(dmlc::Stream* fs) const;
private: private:
// To create empty class
friend class Serializer;
// Empty Constructor, only for serializer
HeteroGraph() : BaseHeteroGraph(static_cast<GraphPtr>(nullptr)) {}
/*! \brief A map from edge type to unit graph */ /*! \brief A map from edge type to unit graph */
std::vector<HeteroGraphPtr> relation_graphs_; std::vector<HeteroGraphPtr> relation_graphs_;
...@@ -183,4 +196,10 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -183,4 +196,10 @@ class HeteroGraph : public BaseHeteroGraph {
} // namespace dgl } // namespace dgl
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::HeteroGraph, true);
} // namespace dmlc
#endif // DGL_GRAPH_HETEROGRAPH_H_ #endif // DGL_GRAPH_HETEROGRAPH_H_
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <string.h> #include <string.h>
#include <bitset> #include <bitset>
#include <numeric> #include <numeric>
...@@ -634,6 +636,28 @@ ImmutableGraphPtr ImmutableGraph::Reverse() const { ...@@ -634,6 +636,28 @@ ImmutableGraphPtr ImmutableGraph::Reverse() const {
} }
} }
constexpr uint64_t kDGLSerialize_ImGraph = 0xDD3c5FFE20046ABF;
/*! \return Load HeteroGraph from stream, using OutCSR Matrix*/
bool ImmutableGraph::Load(dmlc::Stream *fs) {
uint64_t magicNum;
aten::CSRMatrix out_csr_matrix;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_ImGraph) << "Invalid ImmutableGraph Data";
CHECK(fs->Read(&out_csr_matrix)) << "Invalid csr matrix";
CSRPtr csr(new CSR(out_csr_matrix.indptr, out_csr_matrix.indices,
out_csr_matrix.data));
auto g = new ImmutableGraph(nullptr, csr);
*this = *g;
return true;
}
/*! \return Save HeteroGraph to stream, using OutCSR Matrix */
void ImmutableGraph::Save(dmlc::Stream *fs) const {
fs->Write(kDGLSerialize_ImGraph);
fs->Write(GetOutCSR()->ToCSRMatrix());
}
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
......
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
namespace dgl { namespace dgl {
namespace { namespace {
using namespace dgl::aten;
// create metagraph of one node type // create metagraph of one node type
inline GraphPtr CreateUnitGraphMetaGraph1() { inline GraphPtr CreateUnitGraphMetaGraph1() {
// a self-loop edge 0->0 // a self-loop edge 0->0
...@@ -1152,8 +1155,17 @@ SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const { ...@@ -1152,8 +1155,17 @@ SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const {
return SparseFormat::COO; return SparseFormat::COO;
} }
UnitGraph* UnitGraph::EmptyGraph() {
auto src = NewIdArray(0);
auto dst = NewIdArray(0);
auto mg = CreateUnitGraphMetaGraph(1);
COOPtr coo(new COO(mg, 0, 0, src, dst));
return new UnitGraph(mg, nullptr, nullptr, coo);
}
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127; constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;
// Using OurCSR
bool UnitGraph::Load(dmlc::Stream* fs) { bool UnitGraph::Load(dmlc::Stream* fs) {
uint64_t magicNum; uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number"; CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
...@@ -1164,27 +1176,25 @@ bool UnitGraph::Load(dmlc::Stream* fs) { ...@@ -1164,27 +1176,25 @@ bool UnitGraph::Load(dmlc::Stream* fs) {
CHECK(fs->Read(&num_dst)) << "Invalid num_dst"; CHECK(fs->Read(&num_dst)) << "Invalid num_dst";
aten::CSRMatrix csr_matrix; aten::CSRMatrix csr_matrix;
CHECK(fs->Read(&csr_matrix)) << "Invalid csr_matrix"; CHECK(fs->Read(&csr_matrix)) << "Invalid csr_matrix";
SparseFormat restrict_format;
CHECK(fs->Read(&restrict_format)) << "Invalid restrict_format";
auto mg = CreateUnitGraphMetaGraph(num_vtypes); auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr csr(new CSR(mg, num_src, num_dst, csr_matrix.indptr, csr_matrix.indices, csr_matrix.data)); CSRPtr csr(new CSR(mg, num_src, num_dst, csr_matrix.indptr,
*this = UnitGraph(mg, csr, nullptr, nullptr); csr_matrix.indices, csr_matrix.data));
*this = UnitGraph(mg, nullptr, csr, nullptr);
return true; return true;
} }
// Using Out CSR
void UnitGraph::Save(dmlc::Stream* fs) const { void UnitGraph::Save(dmlc::Stream* fs) const {
// Following CreateFromCSR signature // Following CreateFromCSR signature
aten::CSRMatrix csr_matrix = GetInCSRMatrix(); aten::CSRMatrix csr_matrix = GetOutCSRMatrix();
uint64_t num_vtypes = NumVertexTypes(); uint64_t num_vtypes = NumVertexTypes();
uint64_t num_src = NumVertices(SrcType()); uint64_t num_src = NumVertices(SrcType());
uint64_t num_dst = NumVertices(DstType()); uint64_t num_dst = NumVertices(DstType());
SparseFormat restrict_format = restrict_format_;
fs->Write(kDGLSerialize_UnitGraphMagic); fs->Write(kDGLSerialize_UnitGraphMagic);
fs->Write(num_vtypes); fs->Write(num_vtypes);
fs->Write(num_src); fs->Write(num_src);
fs->Write(num_dst); fs->Write(num_dst);
fs->Write(csr_matrix); fs->Write(csr_matrix);
fs->Write(restrict_format);
} }
} // namespace dgl } // namespace dgl
...@@ -186,6 +186,8 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -186,6 +186,8 @@ class UnitGraph : public BaseHeteroGraph {
void Save(dmlc::Stream* fs) const; void Save(dmlc::Stream* fs) const;
private: private:
friend class Serializer;
/*! /*!
* \brief constructor * \brief constructor
* \param metagraph metagraph * \param metagraph metagraph
...@@ -219,6 +221,9 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -219,6 +221,9 @@ class UnitGraph : public BaseHeteroGraph {
/*! \return Whether the graph is hypersparse */ /*! \return Whether the graph is hypersparse */
bool IsHypersparse() const; bool IsHypersparse() const;
// Empty Graph for Serializer Usgae
static UnitGraph* EmptyGraph();
// Graph stored in different format. We use an on-demand strategy: the format is // Graph stored in different format. We use an on-demand strategy: the format is
// only materialized if the operation that suitable for it is invoked. // only materialized if the operation that suitable for it is invoked.
/*! \brief CSR graph that stores reverse edges */ /*! \brief CSR graph that stores reverse edges */
......
#include <dgl/immutable_graph.h>
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include "../../src/graph/graph_serializer.h"
#include "../../src/graph/heterograph.h"
#include "../../src/graph/unit_graph.h" #include "../../src/graph/unit_graph.h"
#include "./common.h" #include "./common.h"
...@@ -22,13 +25,58 @@ TEST(Serialize, UnitGraph) { ...@@ -22,13 +25,58 @@ TEST(Serialize, UnitGraph) {
static_cast<dmlc::Stream*>(&ifs)->Write<UnitGraph>(*ug); static_cast<dmlc::Stream*>(&ifs)->Write<UnitGraph>(*ug);
dmlc::MemoryStringStream ofs(&blob); dmlc::MemoryStringStream ofs(&blob);
src = NewIdArray(0); UnitGraph* ug2 = Serializer::EmptyUnitGraph();
dst = NewIdArray(0);
auto mg2 = dgl::UnitGraph::CreateFromCOO(
1, 0, 0, src, dst); // Any way to construct Empty UnitGraph?
UnitGraph* ug2 = dynamic_cast<UnitGraph*>(mg2.get());
static_cast<dmlc::Stream*>(&ofs)->Read(ug2); static_cast<dmlc::Stream*>(&ofs)->Read(ug2);
EXPECT_EQ(ug2->NumVertices(0), 8); EXPECT_EQ(ug2->NumVertices(0), 9);
EXPECT_EQ(ug2->NumVertices(1), 9); EXPECT_EQ(ug2->NumVertices(1), 8);
EXPECT_EQ(ug2->NumEdges(0), 4); EXPECT_EQ(ug2->NumEdges(0), 4);
EXPECT_EQ(ug2->FindEdge(0, 1).first, 2);
EXPECT_EQ(ug2->FindEdge(0, 1).second, 6);
}
TEST(Serialize, ImmutableGraph) {
auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto gptr = ImmutableGraph::CreateFromCOO(10, src, dst);
ImmutableGraph* rptr = gptr.get();
std::string blob;
dmlc::MemoryStringStream ifs(&blob);
static_cast<dmlc::Stream*>(&ifs)->Write(*rptr);
dmlc::MemoryStringStream ofs(&blob);
ImmutableGraph* rptr_read = new ImmutableGraph(static_cast<COOPtr>(nullptr));
static_cast<dmlc::Stream*>(&ofs)->Read(rptr_read);
EXPECT_EQ(rptr_read->NumEdges(), 4);
EXPECT_EQ(rptr_read->NumVertices(), 10);
EXPECT_EQ(rptr_read->FindEdge(2).first, 5);
EXPECT_EQ(rptr_read->FindEdge(2).second, 2);
}
TEST(Serialize, HeteroGraph) {
auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto mg1 = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst);
src = VecToIdArray<int64_t>({6, 2, 5, 1, 9});
dst = VecToIdArray<int64_t>({5, 2, 4, 9, 0});
auto mg2 = dgl::UnitGraph::CreateFromCOO(1, 9, 9, src, dst);
std::vector<HeteroGraphPtr> relgraphs;
relgraphs.push_back(mg1);
relgraphs.push_back(mg2);
src = VecToIdArray<int64_t>({0, 0});
dst = VecToIdArray<int64_t>({1, 0});
auto meta_gptr = ImmutableGraph::CreateFromCOO(2, src, dst);
HeteroGraph* hrptr = new HeteroGraph(meta_gptr, relgraphs);
std::string blob;
dmlc::MemoryStringStream ifs(&blob);
static_cast<dmlc::Stream*>(&ifs)->Write(*hrptr);
dmlc::MemoryStringStream ofs(&blob);
HeteroGraph* gptr = dgl::Serializer::EmptyHeteroGraph();
static_cast<dmlc::Stream*>(&ofs)->Read(gptr);
EXPECT_EQ(gptr->NumVertices(0), 9);
EXPECT_EQ(gptr->NumVertices(1), 8);
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment