Unverified Commit 9caff617 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Feature] More serialization support (#1295)



* fix script

* t

* fix weird bugs

* fix

* fix

* upload

* fix

* fix

* lint

* fix

* tmp

* fix serialization

* fix

* fix lint

* fix message

* fix

* lint

* address comment

* fix

* lint

* fix

* fix

* fix

* Remove duplicate serialization for meta graph
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent 09ade2f2
...@@ -292,7 +292,10 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths); ...@@ -292,7 +292,10 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
* Note that we do allow duplicate non-zero entries -- multiple non-zero entries * Note that we do allow duplicate non-zero entries -- multiple non-zero entries
* that have the same row, col indices. It corresponds to multigraph in * that have the same row, col indices. It corresponds to multigraph in
* graph terminology. * graph terminology.
*/ */
constexpr uint64_t kDGLSerialize_AtenCsrMatrixMagic = 0xDD6cd31205dff127;
struct CSRMatrix { struct CSRMatrix {
/*! \brief the dense shape of the matrix */ /*! \brief the dense shape of the matrix */
int64_t num_rows = 0, num_cols = 0; int64_t num_rows = 0, num_cols = 0;
...@@ -305,36 +308,67 @@ struct CSRMatrix { ...@@ -305,36 +308,67 @@ struct CSRMatrix {
/*! \brief default constructor */ /*! \brief default constructor */
CSRMatrix() = default; CSRMatrix() = default;
/*! \brief constructor */ /*! \brief constructor */
CSRMatrix(int64_t nrows, int64_t ncols, CSRMatrix(int64_t nrows, int64_t ncols, IdArray parr, IdArray iarr,
IdArray parr, IdArray iarr, IdArray darr = NullArray(), IdArray darr = NullArray(), bool sorted_flag = false)
bool sorted_flag = false) : num_rows(nrows),
: num_rows(nrows), num_cols(ncols), indptr(parr), indices(iarr), num_cols(ncols),
data(darr), sorted(sorted_flag) {} indptr(parr),
indices(iarr),
data(darr),
sorted(sorted_flag) {}
/*! \brief constructor from SparseMatrix object */ /*! \brief constructor from SparseMatrix object */
explicit CSRMatrix(const SparseMatrix& spmat) explicit CSRMatrix(const SparseMatrix& spmat)
: num_rows(spmat.num_rows), num_cols(spmat.num_cols), : num_rows(spmat.num_rows),
indptr(spmat.indices[0]), indices(spmat.indices[1]), data(spmat.indices[2]), num_cols(spmat.num_cols),
sorted(spmat.flags[0]) {} indptr(spmat.indices[0]),
indices(spmat.indices[1]),
data(spmat.indices[2]),
sorted(spmat.flags[0]) {}
// Convert to a SparseMatrix object that can return to python. // Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const { SparseMatrix ToSparseMatrix() const {
return SparseMatrix(static_cast<int32_t>(SparseFormat::CSR), return SparseMatrix(static_cast<int32_t>(SparseFormat::CSR), num_rows,
num_rows, num_cols, num_cols, {indptr, indices, data}, {sorted});
{indptr, indices, data}, }
{sorted});
bool Load(dmlc::Stream* fs) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_AtenCsrMatrixMagic)
<< "Invalid CSRMatrix Data";
CHECK(fs->Read(&num_cols)) << "Invalid num_cols";
CHECK(fs->Read(&num_rows)) << "Invalid num_rows";
CHECK(fs->Read(&indptr)) << "Invalid indptr";
CHECK(fs->Read(&indices)) << "Invalid indices";
CHECK(fs->Read(&data)) << "Invalid data";
CHECK(fs->Read(&sorted)) << "Invalid sorted";
return true;
}
void Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_AtenCsrMatrixMagic);
fs->Write(num_cols);
fs->Write(num_rows);
fs->Write(indptr);
fs->Write(indices);
fs->Write(data);
fs->Write(sorted);
} }
}; };
/*! /*!
* \brief Plain COO structure * \brief Plain COO structure
* *
* The data array stores integer ids for reading edge features. * The data array stores integer ids for reading edge features.
* Note that we do allow duplicate non-zero entries -- multiple non-zero entries * Note that we do allow duplicate non-zero entries -- multiple non-zero entries
* that have the same row, col indices. It corresponds to multigraph in * that have the same row, col indices. It corresponds to multigraph in
* graph terminology. * graph terminology.
*/ */
constexpr uint64_t kDGLSerialize_AtenCooMatrixMagic = 0xDD61ffd305dff127;
// TODO(BarclayII): Graph queries on COO formats should support the case where // TODO(BarclayII): Graph queries on COO formats should support the case where
// data ordered by rows/columns instead of EID. // data ordered by rows/columns instead of EID.
struct COOMatrix { struct COOMatrix {
...@@ -351,24 +385,57 @@ struct COOMatrix { ...@@ -351,24 +385,57 @@ struct COOMatrix {
/*! \brief default constructor */ /*! \brief default constructor */
COOMatrix() = default; COOMatrix() = default;
/*! \brief constructor */ /*! \brief constructor */
COOMatrix(int64_t nrows, int64_t ncols, COOMatrix(int64_t nrows, int64_t ncols, IdArray rarr, IdArray carr,
IdArray rarr, IdArray carr, IdArray darr = NullArray(), IdArray darr = NullArray(), bool rsorted = false,
bool rsorted = false, bool csorted = false) bool csorted = false)
: num_rows(nrows), num_cols(ncols), row(rarr), col(carr), data(darr), : num_rows(nrows),
row_sorted(rsorted), col_sorted(csorted) {} num_cols(ncols),
row(rarr),
col(carr),
data(darr),
row_sorted(rsorted),
col_sorted(csorted) {}
/*! \brief constructor from SparseMatrix object */ /*! \brief constructor from SparseMatrix object */
explicit COOMatrix(const SparseMatrix& spmat) explicit COOMatrix(const SparseMatrix& spmat)
: num_rows(spmat.num_rows), num_cols(spmat.num_cols), : num_rows(spmat.num_rows),
row(spmat.indices[0]), col(spmat.indices[1]), data(spmat.indices[2]), num_cols(spmat.num_cols),
row_sorted(spmat.flags[0]), col_sorted(spmat.flags[1]) {} row(spmat.indices[0]),
col(spmat.indices[1]),
data(spmat.indices[2]),
row_sorted(spmat.flags[0]),
col_sorted(spmat.flags[1]) {}
// Convert to a SparseMatrix object that can return to python. // Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const { SparseMatrix ToSparseMatrix() const {
return SparseMatrix(static_cast<int32_t>(SparseFormat::COO), return SparseMatrix(static_cast<int32_t>(SparseFormat::COO), num_rows,
num_rows, num_cols, num_cols, {row, col, data}, {row_sorted, col_sorted});
{row, col, data}, }
{row_sorted, col_sorted});
bool Load(dmlc::Stream* fs) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_AtenCooMatrixMagic)
<< "Invalid COOMatrix Data";
CHECK(fs->Read(&num_cols)) << "Invalid num_cols";
CHECK(fs->Read(&num_rows)) << "Invalid num_rows";
CHECK(fs->Read(&row)) << "Invalid row";
CHECK(fs->Read(&col)) << "Invalid col";
CHECK(fs->Read(&data)) << "Invalid data";
CHECK(fs->Read(&row_sorted)) << "Invalid row_sorted";
CHECK(fs->Read(&col_sorted)) << "Invalid col_sorted";
return true;
}
void Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_AtenCooMatrixMagic);
fs->Write(num_cols);
fs->Write(num_rows);
fs->Write(row);
fs->Write(col);
fs->Write(data);
fs->Write(row_sorted);
fs->Write(col_sorted);
} }
}; };
...@@ -915,39 +982,8 @@ IdArray VecToIdArray(const std::vector<T>& vec, ...@@ -915,39 +982,8 @@ IdArray VecToIdArray(const std::vector<T>& vec,
} // namespace dgl } // namespace dgl
namespace dmlc { namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::aten::CSRMatrix, true);
namespace serializer { DMLC_DECLARE_TRAITS(has_saveload, dgl::aten::COOMatrix, true);
using dgl::aten::CSRMatrix;
constexpr uint64_t kDGLSerialize_AtenCsrMatrixMagic = 0xDD6cd31205dff127;
template <>
struct Handler<CSRMatrix> {
inline static void Write(Stream* fs, const CSRMatrix& csr) {
fs->Write(kDGLSerialize_AtenCsrMatrixMagic);
fs->Write(csr.num_cols);
fs->Write(csr.num_rows);
fs->Write(csr.indptr);
fs->Write(csr.indices);
fs->Write(csr.data);
fs->Write(csr.sorted);
}
inline static bool Read(Stream* fs, CSRMatrix* csr) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_AtenCsrMatrixMagic)
<< "Invalid CSRMatrix Data";
CHECK(fs->Read(&csr->num_cols)) << "Invalid num_cols";
CHECK(fs->Read(&csr->num_rows)) << "Invalid num_rows";
CHECK(fs->Read(&csr->indptr)) << "Invalid indptr";
CHECK(fs->Read(&csr->indices)) << "Invalid indices";
CHECK(fs->Read(&csr->data)) << "Invalid data";
CHECK(fs->Read(&csr->sorted)) << "Invalid sorted";
return true;
}
};
} // namespace serializer
} // namespace dmlc } // namespace dmlc
#endif // DGL_ARRAY_H_ #endif // DGL_ARRAY_H_
...@@ -438,6 +438,9 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -438,6 +438,9 @@ class BaseHeteroGraph : public runtime::Object {
protected: protected:
/*! \brief meta graph */ /*! \brief meta graph */
GraphPtr meta_graph_; GraphPtr meta_graph_;
// empty constructor
BaseHeteroGraph(){}
}; };
// Define HeteroGraphRef // Define HeteroGraphRef
......
...@@ -4,22 +4,25 @@ ...@@ -4,22 +4,25 @@
* \brief DGL serializer APIs * \brief DGL serializer APIs
*/ */
#pragma once #ifndef DGL_GRAPH_SERIALIZER_H_
#define DGL_GRAPH_SERIALIZER_H_
#include <dgl/immutable_graph.h>
#include "heterograph.h"
#include "unit_graph.h"
#include <memory>
namespace dgl { namespace dgl {
// Util class to call the private/public empty constructor, which is needed for serialization
class Serializer { class Serializer {
public: public:
static HeteroGraph* EmptyHeteroGraph() { return new HeteroGraph(); } template <typename T>
static ImmutableGraph* EmptyImmutableGraph() { static T* new_object() {
return new ImmutableGraph(static_cast<COOPtr>(nullptr)); return new T();
} }
static UnitGraph* EmptyUnitGraph() {
return UnitGraph::EmptyGraph(); template <typename T>
static std::shared_ptr<T> make_shared() {
return std::shared_ptr<T>(new T());
} }
}; };
} // namespace dgl } // namespace dgl
#endif // DGL_GRAPH_SERIALIZER_H_
...@@ -240,6 +240,12 @@ class CSR : public GraphInterface { ...@@ -240,6 +240,12 @@ class CSR : public GraphInterface {
IdArray edge_ids() const { return adj_.data; } IdArray edge_ids() const { return adj_.data; }
/*! \return Load CSR from stream */
bool Load(dmlc::Stream *fs);
/*! \return Save CSR to stream */
void Save(dmlc::Stream* fs) const;
void SortCSR() override { void SortCSR() override {
if (adj_.sorted) if (adj_.sorted)
return; return;
...@@ -247,9 +253,10 @@ class CSR : public GraphInterface { ...@@ -247,9 +253,10 @@ class CSR : public GraphInterface {
} }
private: private:
/*! \brief prive default constructor */ friend class Serializer;
CSR() {adj_.sorted = false;}
/*! \brief private default constructor */
CSR() {adj_.sorted = false;}
// The internal CSR adjacency matrix. // The internal CSR adjacency matrix.
// The data field stores edge ids. // The data field stores edge ids.
aten::CSRMatrix adj_; aten::CSRMatrix adj_;
...@@ -957,10 +964,10 @@ class ImmutableGraph: public GraphInterface { ...@@ -957,10 +964,10 @@ class ImmutableGraph: public GraphInterface {
*/ */
ImmutableGraphPtr Reverse() const; ImmutableGraphPtr Reverse() const;
/*! \return Load HeteroGraph from stream, using CSRMatrix*/ /*! \return Load ImmutableGraph from stream, using out csr */
bool Load(dmlc::Stream* fs); bool Load(dmlc::Stream *fs);
/*! \return Save HeteroGraph to stream, using CSRMatrix */ /*! \return Save ImmutableGraph to stream, using out csr */
void Save(dmlc::Stream* fs) const; void Save(dmlc::Stream* fs) const;
void SortCSR() { void SortCSR() {
...@@ -969,6 +976,8 @@ class ImmutableGraph: public GraphInterface { ...@@ -969,6 +976,8 @@ class ImmutableGraph: public GraphInterface {
} }
protected: protected:
friend class Serializer;
/* !\brief internal default constructor */ /* !\brief internal default constructor */
ImmutableGraph() {} ImmutableGraph() {}
...@@ -1034,6 +1043,7 @@ CSR::CSR(int64_t num_vertices, int64_t num_edges, ...@@ -1034,6 +1043,7 @@ CSR::CSR(int64_t num_vertices, int64_t num_edges,
} // namespace dgl } // namespace dgl
namespace dmlc { namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::CSR, true);
DMLC_DECLARE_TRAITS(has_saveload, dgl::ImmutableGraph, true); DMLC_DECLARE_TRAITS(has_saveload, dgl::ImmutableGraph, true);
} // namespace dmlc } // namespace dmlc
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <dmlc/io.h> #include <dmlc/io.h>
#include <dmlc/serializer.h> #include <dmlc/serializer.h>
#include "c_runtime_api.h" #include "c_runtime_api.h"
#include "smart_ptr_serializer.h"
#include "ndarray.h" #include "ndarray.h"
namespace dmlc { namespace dmlc {
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#ifndef DGL_RUNTIME_SMART_PTR_SERIALIZER_H_ #ifndef DGL_RUNTIME_SMART_PTR_SERIALIZER_H_
#define DGL_RUNTIME_SMART_PTR_SERIALIZER_H_ #define DGL_RUNTIME_SMART_PTR_SERIALIZER_H_
#include <dgl/graph_serializer.h>
#include <dmlc/io.h> #include <dmlc/io.h>
#include <dmlc/serializer.h> #include <dmlc/serializer.h>
...@@ -24,7 +25,7 @@ struct Handler<std::shared_ptr<T>> { ...@@ -24,7 +25,7 @@ struct Handler<std::shared_ptr<T>> {
// shared_ptr<T>(), which is holding a nullptr. Here we need to manually // shared_ptr<T>(), which is holding a nullptr. Here we need to manually
// reset to a real object for further loading // reset to a real object for further loading
if (!(*data)) { if (!(*data)) {
data->reset(new T()); data->reset(dgl::Serializer::new_object<T>());
} }
return Handler<T>::Read(strm, data->get()); return Handler<T>::Read(strm, data->get());
} }
...@@ -40,7 +41,7 @@ struct Handler<std::unique_ptr<T>> { ...@@ -40,7 +41,7 @@ struct Handler<std::unique_ptr<T>> {
// unique_ptr<T>(), which is holding a nullptr. Here we need to manually // unique_ptr<T>(), which is holding a nullptr. Here we need to manually
// reset to a real object for further loading // reset to a real object for further loading
if (!(*data)) { if (!(*data)) {
data->reset(new T()); data->reset(dgl::Serializer::new_object<T>());
} }
return Handler<T>::Read(strm, data->get()); return Handler<T>::Read(strm, data->get());
} }
......
...@@ -4,14 +4,12 @@ ...@@ -4,14 +4,12 @@
* \brief Heterograph implementation * \brief Heterograph implementation
*/ */
#include "./heterograph.h" #include "./heterograph.h"
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/graph_serializer.h>
#include <vector> #include <vector>
#include <tuple> #include <tuple>
#include <utility> #include <utility>
#include "graph_serializer.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -116,7 +114,7 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& ...@@ -116,7 +114,7 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size()); CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size());
CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed."; CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed.";
// all relation graphs must have only one edge type // all relation graphs must have only one edge type
for (const auto rg : rel_graphs) { for (const auto &rg : rel_graphs) {
CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must have only one edge type."; CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must have only one edge type.";
} }
// create num verts per type // create num verts per type
...@@ -166,7 +164,7 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& ...@@ -166,7 +164,7 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
bool HeteroGraph::IsMultigraph() const { bool HeteroGraph::IsMultigraph() const {
return const_cast<HeteroGraph*>(this)->is_multigraph_.Get([this] () { return const_cast<HeteroGraph*>(this)->is_multigraph_.Get([this] () {
for (const auto hg : relation_graphs_) { for (const auto &hg : relation_graphs_) {
if (hg->IsMultigraph()) { if (hg->IsMultigraph()) {
return true; return true;
} }
...@@ -316,31 +314,20 @@ bool HeteroGraph::Load(dmlc::Stream* fs) { ...@@ -316,31 +314,20 @@ bool HeteroGraph::Load(dmlc::Stream* fs) {
uint64_t magicNum; uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number"; CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_HeteroGraph) << "Invalid HeteroGraph Data"; CHECK_EQ(magicNum, kDGLSerialize_HeteroGraph) << "Invalid HeteroGraph Data";
auto meta_grptr = new ImmutableGraph(static_cast<COOPtr>(nullptr)); auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
CHECK(fs->Read(meta_grptr)) << "Invalid Immutable Graph Data"; CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
uint64_t num_relation_graphs; meta_graph_ = meta_imgraph;
CHECK(fs->Read(&num_relation_graphs)) << "Invalid num of relation graphs"; CHECK(fs->Read(&relation_graphs_)) << "Invalid relation_graphs_";
std::vector<HeteroGraphPtr> relgraphs; CHECK(fs->Read(&num_verts_per_type_)) << "Invalid num_verts_per_type_";
for (size_t i = 0; i < num_relation_graphs; ++i) {
UnitGraph* ugptr = Serializer::EmptyUnitGraph();
CHECK(fs->Read(ugptr)) << "Invalid UnitGraph Data";
relgraphs.emplace_back(dynamic_cast<BaseHeteroGraph*>(ugptr));
}
HeteroGraph* hgptr = new HeteroGraph(GraphPtr(meta_grptr), relgraphs);
*this = *hgptr;
return true; return true;
} }
void HeteroGraph::Save(dmlc::Stream* fs) const { void HeteroGraph::Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_HeteroGraph); fs->Write(kDGLSerialize_HeteroGraph);
auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph()); auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
ImmutableGraph* meta_rptr = meta_graph_ptr.get(); fs->Write(meta_graph_ptr);
fs->Write(*meta_rptr); fs->Write(relation_graphs_);
fs->Write(static_cast<uint64_t>(relation_graphs_.size())); fs->Write(num_verts_per_type_);
for (auto hptr : relation_graphs_) {
auto rptr = dynamic_cast<UnitGraph*>(hptr.get());
fs->Write(*rptr);
}
} }
} // namespace dgl } // namespace dgl
...@@ -198,7 +198,7 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -198,7 +198,7 @@ class HeteroGraph : public BaseHeteroGraph {
friend class Serializer; friend class Serializer;
// Empty Constructor, only for serializer // Empty Constructor, only for serializer
HeteroGraph() : BaseHeteroGraph(static_cast<GraphPtr>(nullptr)) {} HeteroGraph() : BaseHeteroGraph() {}
/*! \brief A map from edge type to unit graph */ /*! \brief A map from edge type to unit graph */
std::vector<UnitGraphPtr> relation_graphs_; std::vector<UnitGraphPtr> relation_graphs_;
......
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
* \brief DGL immutable graph index implementation * \brief DGL immutable graph index implementation
*/ */
#include <dgl/packed_func_ext.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/smart_ptr_serializer.h>
#include <dmlc/io.h> #include <dmlc/io.h>
#include <dmlc/type_traits.h> #include <dmlc/type_traits.h>
#include <string.h> #include <string.h>
...@@ -276,6 +277,15 @@ DGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const { ...@@ -276,6 +277,15 @@ DGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const {
return DGLIdIters(eid_data + start, eid_data + end); return DGLIdIters(eid_data + start, eid_data + end);
} }
bool CSR::Load(dmlc::Stream *fs) {
fs->Read(const_cast<dgl::aten::CSRMatrix*>(&adj_));
return true;
}
void CSR::Save(dmlc::Stream *fs) const {
fs->Write(adj_);
}
////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////
// //
// COO graph implementation // COO graph implementation
...@@ -643,19 +653,16 @@ bool ImmutableGraph::Load(dmlc::Stream *fs) { ...@@ -643,19 +653,16 @@ bool ImmutableGraph::Load(dmlc::Stream *fs) {
uint64_t magicNum; uint64_t magicNum;
aten::CSRMatrix out_csr_matrix; aten::CSRMatrix out_csr_matrix;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number"; CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_ImGraph) << "Invalid ImmutableGraph Data"; CHECK_EQ(magicNum, kDGLSerialize_ImGraph)
CHECK(fs->Read(&out_csr_matrix)) << "Invalid csr matrix"; << "Invalid ImmutableGraph Magic Number";
CSRPtr csr(new CSR(out_csr_matrix.indptr, out_csr_matrix.indices, CHECK(fs->Read(&out_csr_)) << "Invalid csr matrix";
out_csr_matrix.data));
auto g = new ImmutableGraph(nullptr, csr);
*this = *g;
return true; return true;
} }
/*! \return Save HeteroGraph to stream, using OutCSR Matrix */ /*! \return Save HeteroGraph to stream, using OutCSR Matrix */
void ImmutableGraph::Save(dmlc::Stream *fs) const { void ImmutableGraph::Save(dmlc::Stream *fs) const {
fs->Write(kDGLSerialize_ImGraph); fs->Write(kDGLSerialize_ImGraph);
fs->Write(GetOutCSR()->ToCSRMatrix()); fs->Write(GetOutCSR());
} }
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
......
...@@ -4,12 +4,12 @@ ...@@ -4,12 +4,12 @@
* \brief UnitGraph graph implementation * \brief UnitGraph graph implementation
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/lazy.h>
#include <dgl/immutable_graph.h>
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include <dgl/immutable_graph.h>
#include <dgl/lazy.h>
#include "./unit_graph.h"
#include "../c_api_common.h" #include "../c_api_common.h"
#include "./unit_graph.h"
namespace dgl { namespace dgl {
...@@ -401,7 +401,24 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -401,7 +401,24 @@ class UnitGraph::COO : public BaseHeteroGraph {
(NumVertices(SrcType()) > 1000000); (NumVertices(SrcType()) > 1000000);
} }
bool Load(dmlc::Stream* fs) {
auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
meta_graph_ = meta_imgraph;
CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
return true;
}
void Save(dmlc::Stream* fs) const {
auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
fs->Write(meta_graph_ptr);
fs->Write(adj_);
}
private: private:
friend class Serializer;
COO() {}
/*! \brief internal adjacency matrix. Data array is empty */ /*! \brief internal adjacency matrix. Data array is empty */
aten::COOMatrix adj_; aten::COOMatrix adj_;
...@@ -427,7 +444,6 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -427,7 +444,6 @@ class UnitGraph::CSR : public BaseHeteroGraph {
CHECK_EQ(indices->shape[0], edge_ids->shape[0]) CHECK_EQ(indices->shape[0], edge_ids->shape[0])
<< "indices and edge id arrays should have the same length"; << "indices and edge id arrays should have the same length";
adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids}; adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
sorted_ = false;
} }
CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst, CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
...@@ -439,12 +455,10 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -439,12 +455,10 @@ class UnitGraph::CSR : public BaseHeteroGraph {
CHECK_EQ(indices->shape[0], edge_ids->shape[0]) CHECK_EQ(indices->shape[0], edge_ids->shape[0])
<< "indices and edge id arrays should have the same length"; << "indices and edge id arrays should have the same length";
adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids}; adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
sorted_ = false;
} }
CSR(GraphPtr metagraph, const aten::CSRMatrix& csr) CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
: BaseHeteroGraph(metagraph), adj_(csr) { : BaseHeteroGraph(metagraph), adj_(csr) {
sorted_ = false;
} }
inline dgl_type_t SrcType() const { inline dgl_type_t SrcType() const {
...@@ -738,15 +752,30 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -738,15 +752,30 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return adj_; return adj_;
} }
bool Load(dmlc::Stream* fs) {
auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
meta_graph_ = meta_imgraph;
CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
return true;
}
void Save(dmlc::Stream* fs) const {
auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
fs->Write(meta_graph_ptr);
fs->Write(adj_);
}
private: private:
friend class Serializer;
CSR() {};
/*! \brief internal adjacency matrix. Data array stores edge ids */ /*! \brief internal adjacency matrix. Data array stores edge ids */
aten::CSRMatrix adj_; aten::CSRMatrix adj_;
/*! \brief multi-graph flag */ /*! \brief multi-graph flag */
Lazy<bool> is_multigraph_; Lazy<bool> is_multigraph_;
/*! \brief indicate that the edges are stored in the sorted order. */
bool sorted_;
}; };
////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////
...@@ -1192,17 +1221,17 @@ HeteroGraphPtr UnitGraph::GetAny() const { ...@@ -1192,17 +1221,17 @@ HeteroGraphPtr UnitGraph::GetAny() const {
HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const { HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
switch (format) { switch (format) {
case SparseFormat::CSR: case SparseFormat::CSR:
return GetOutCSR(); return GetOutCSR();
case SparseFormat::CSC: case SparseFormat::CSC:
return GetInCSR(); return GetInCSR();
case SparseFormat::COO: case SparseFormat::COO:
return GetCOO(); return GetCOO();
case SparseFormat::ANY: case SparseFormat::ANY:
return GetAny(); return GetAny();
default: default:
LOG(FATAL) << "unsupported format code"; LOG(FATAL) << "unsupported format code";
return nullptr; return nullptr;
} }
} }
...@@ -1219,46 +1248,58 @@ SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const { ...@@ -1219,46 +1248,58 @@ SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const {
return SparseFormat::COO; return SparseFormat::COO;
} }
UnitGraph* UnitGraph::EmptyGraph() {
auto src = NewIdArray(0);
auto dst = NewIdArray(0);
auto mg = CreateUnitGraphMetaGraph(1);
COOPtr coo(new COO(mg, 0, 0, src, dst));
return new UnitGraph(mg, nullptr, nullptr, coo);
}
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127; constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;
// Using OurCSR
bool UnitGraph::Load(dmlc::Stream* fs) { bool UnitGraph::Load(dmlc::Stream* fs) {
uint64_t magicNum; uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number"; CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data"; CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data";
uint64_t num_vtypes, num_src, num_dst;
CHECK(fs->Read(&num_vtypes)) << "Invalid num_vtypes"; int64_t format_code;
CHECK(fs->Read(&num_src)) << "Invalid num_src"; CHECK(fs->Read(&format_code)) << "Invalid format";
CHECK(fs->Read(&num_dst)) << "Invalid num_dst"; restrict_format_ = static_cast<SparseFormat>(format_code);
aten::CSRMatrix csr_matrix;
CHECK(fs->Read(&csr_matrix)) << "Invalid csr_matrix"; switch (restrict_format_) {
auto mg = CreateUnitGraphMetaGraph(num_vtypes); case SparseFormat::COO:
CSRPtr csr(new CSR(mg, num_src, num_dst, csr_matrix.indptr, fs->Read(&coo_);
csr_matrix.indices, csr_matrix.data)); break;
*this = UnitGraph(mg, nullptr, csr, nullptr); case SparseFormat::CSR:
fs->Read(&out_csr_);
break;
case SparseFormat::CSC:
fs->Read(&in_csr_);
break;
default:
LOG(FATAL) << "unsupported format code";
break;
}
meta_graph_ = GetAny()->meta_graph();
return true; return true;
} }
// Using Out CSR
void UnitGraph::Save(dmlc::Stream* fs) const { void UnitGraph::Save(dmlc::Stream* fs) const {
// Following CreateFromCSR signature
aten::CSRMatrix csr_matrix = GetCSRMatrix(0);
uint64_t num_vtypes = NumVertexTypes();
uint64_t num_src = NumVertices(SrcType());
uint64_t num_dst = NumVertices(DstType());
fs->Write(kDGLSerialize_UnitGraphMagic); fs->Write(kDGLSerialize_UnitGraphMagic);
fs->Write(num_vtypes); // Didn't write UnitGraph::meta_graph_, since it's included in the underlying
fs->Write(num_src); // sparse matrix
fs->Write(num_dst); auto avail_fmt = SelectFormat(SparseFormat::ANY);
fs->Write(csr_matrix); fs->Write(static_cast<int64_t>(avail_fmt));
switch (avail_fmt) {
case SparseFormat::COO:
fs->Write(GetCOO());
break;
case SparseFormat::CSR:
fs->Write(GetOutCSR());
break;
case SparseFormat::CSC:
fs->Write(GetInCSR());
break;
default:
LOG(FATAL) << "unsupported format code";
break;
}
} }
} // namespace dgl } // namespace dgl
...@@ -215,6 +215,9 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -215,6 +215,9 @@ class UnitGraph : public BaseHeteroGraph {
friend class Serializer; friend class Serializer;
friend class HeteroGraph; friend class HeteroGraph;
// private empty constructor
UnitGraph() {}
/*! /*!
* \brief constructor * \brief constructor
* \param metagraph metagraph * \param metagraph metagraph
...@@ -250,9 +253,6 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -250,9 +253,6 @@ class UnitGraph : public BaseHeteroGraph {
/*! \return Whether the graph is hypersparse */ /*! \return Whether the graph is hypersparse */
bool IsHypersparse() const; bool IsHypersparse() const;
// Empty Graph for Serializer Usgae
static UnitGraph* EmptyGraph();
// Graph stored in different format. We use an on-demand strategy: the format is // Graph stored in different format. We use an on-demand strategy: the format is
// only materialized if the operation that suitable for it is invoked. // only materialized if the operation that suitable for it is invoked.
/*! \brief CSR graph that stores reverse edges */ /*! \brief CSR graph that stores reverse edges */
...@@ -275,6 +275,8 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -275,6 +275,8 @@ class UnitGraph : public BaseHeteroGraph {
namespace dmlc { namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::UnitGraph, true); DMLC_DECLARE_TRAITS(has_saveload, dgl::UnitGraph, true);
DMLC_DECLARE_TRAITS(has_saveload, dgl::UnitGraph::CSR, true);
DMLC_DECLARE_TRAITS(has_saveload, dgl::UnitGraph::COO, true);
} // namespace dmlc } // namespace dmlc
#endif // DGL_GRAPH_UNIT_GRAPH_H_ #endif // DGL_GRAPH_UNIT_GRAPH_H_
#include <dgl/graph_serializer.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include "../../src/graph/graph_serializer.h"
#include "../../src/graph/heterograph.h" #include "../../src/graph/heterograph.h"
#include "../../src/graph/unit_graph.h" #include "../../src/graph/unit_graph.h"
#include "./common.h" #include "./common.h"
...@@ -13,50 +13,65 @@ using namespace dgl; ...@@ -13,50 +13,65 @@ using namespace dgl;
using namespace dgl::aten; using namespace dgl::aten;
using namespace dmlc; using namespace dmlc;
TEST(Serialize, DISABLED_UnitGraph) { TEST(Serialize, UnitGraph_COO) {
aten::CSRMatrix csr_matrix; aten::CSRMatrix csr_matrix;
auto src = VecToIdArray<int64_t>({1, 2, 5, 3}); auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6}); auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst); auto mg = std::dynamic_pointer_cast<UnitGraph>(
UnitGraph* ug = dynamic_cast<UnitGraph*>(mg.get()); dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, dgl::SparseFormat::COO));
std::string blob; std::string blob;
dmlc::MemoryStringStream ifs(&blob); dmlc::MemoryStringStream ifs(&blob);
static_cast<dmlc::Stream*>(&ifs)->Write<UnitGraph>(*ug); static_cast<dmlc::Stream *>(&ifs)->Write(mg);
dmlc::MemoryStringStream ofs(&blob); dmlc::MemoryStringStream ofs(&blob);
UnitGraph* ug2 = Serializer::EmptyUnitGraph(); auto ug2 = Serializer::make_shared<UnitGraph>();
static_cast<dmlc::Stream*>(&ofs)->Read(ug2); static_cast<dmlc::Stream *>(&ofs)->Read(&ug2);
EXPECT_EQ(ug2->NumVertices(0), 9); EXPECT_EQ(ug2->NumVertices(0), 9);
EXPECT_EQ(ug2->NumVertices(1), 8); EXPECT_EQ(ug2->NumVertices(1), 8);
EXPECT_EQ(ug2->NumEdges(0), 4); EXPECT_EQ(ug2->NumEdges(0), 4);
EXPECT_EQ(ug2->FindEdge(0, 1).first, 2); EXPECT_EQ(ug2->FindEdge(0, 1).first, 2);
EXPECT_EQ(ug2->FindEdge(0, 1).second, 6); EXPECT_EQ(ug2->FindEdge(0, 1).second, 6);
delete ug2;
} }
TEST(Serialize, DISABLED_ImmutableGraph) { TEST(Serialize, UnitGraph_CSR) {
aten::CSRMatrix csr_matrix;
auto src = VecToIdArray<int64_t>({1, 2, 5, 3}); auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6}); auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto gptr = ImmutableGraph::CreateFromCOO(10, src, dst); auto mg = std::dynamic_pointer_cast<UnitGraph>(
ImmutableGraph* rptr = gptr.get(); dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, dgl::SparseFormat::CSR));
std::string blob;
dmlc::MemoryStringStream ifs(&blob);
static_cast<dmlc::Stream *>(&ifs)->Write(mg);
dmlc::MemoryStringStream ofs(&blob);
auto ug2 = Serializer::make_shared<UnitGraph>();
static_cast<dmlc::Stream *>(&ofs)->Read(&ug2);
// Query operation is not supported on CSR, how to check it?
}
TEST(Serialize, ImmutableGraph) {
auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto gptr = ImmutableGraph::CreateFromCOO(10, src, dst);
std::string blob; std::string blob;
dmlc::MemoryStringStream ifs(&blob); dmlc::MemoryStringStream ifs(&blob);
static_cast<dmlc::Stream*>(&ifs)->Write(*rptr); static_cast<dmlc::Stream *>(&ifs)->Write(gptr);
dmlc::MemoryStringStream ofs(&blob); dmlc::MemoryStringStream ofs(&blob);
ImmutableGraph* rptr_read = new ImmutableGraph(static_cast<COOPtr>(nullptr)); auto rptr_read = dgl::Serializer::make_shared<ImmutableGraph>();
static_cast<dmlc::Stream*>(&ofs)->Read(rptr_read); static_cast<dmlc::Stream *>(&ofs)->Read(&rptr_read);
EXPECT_EQ(rptr_read->NumEdges(), 4); EXPECT_EQ(rptr_read->NumEdges(), 4);
EXPECT_EQ(rptr_read->NumVertices(), 10); EXPECT_EQ(rptr_read->NumVertices(), 10);
EXPECT_EQ(rptr_read->FindEdge(2).first, 5); EXPECT_EQ(rptr_read->FindEdge(2).first, 5);
EXPECT_EQ(rptr_read->FindEdge(2).second, 2); EXPECT_EQ(rptr_read->FindEdge(2).second, 2);
delete rptr_read;
} }
TEST(Serialize, DISABLED_HeteroGraph) { TEST(Serialize, HeteroGraph) {
auto src = VecToIdArray<int64_t>({1, 2, 5, 3}); auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6}); auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto mg1 = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst); auto mg1 = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst);
...@@ -68,18 +83,16 @@ TEST(Serialize, DISABLED_HeteroGraph) { ...@@ -68,18 +83,16 @@ TEST(Serialize, DISABLED_HeteroGraph) {
relgraphs.push_back(mg2); relgraphs.push_back(mg2);
src = VecToIdArray<int64_t>({0, 0}); src = VecToIdArray<int64_t>({0, 0});
dst = VecToIdArray<int64_t>({1, 0}); dst = VecToIdArray<int64_t>({1, 0});
auto meta_gptr = ImmutableGraph::CreateFromCOO(2, src, dst); auto meta_gptr = ImmutableGraph::CreateFromCOO(3, src, dst);
HeteroGraph* hrptr = new HeteroGraph(meta_gptr, relgraphs); auto hrptr = std::make_shared<HeteroGraph>(meta_gptr, relgraphs);
std::string blob; std::string blob;
dmlc::MemoryStringStream ifs(&blob); dmlc::MemoryStringStream ifs(&blob);
static_cast<dmlc::Stream*>(&ifs)->Write(*hrptr); static_cast<dmlc::Stream *>(&ifs)->Write(hrptr);
dmlc::MemoryStringStream ofs(&blob); dmlc::MemoryStringStream ofs(&blob);
HeteroGraph* gptr = dgl::Serializer::EmptyHeteroGraph(); auto gptr = dgl::Serializer::make_shared<HeteroGraph>();
static_cast<dmlc::Stream*>(&ofs)->Read(gptr); static_cast<dmlc::Stream *>(&ofs)->Read(&gptr);
EXPECT_EQ(gptr->NumVertices(0), 9); EXPECT_EQ(gptr->NumVertices(0), 9);
EXPECT_EQ(gptr->NumVertices(1), 8); EXPECT_EQ(gptr->NumVertices(1), 8);
delete hrptr;
delete gptr;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment