Unverified Commit 9caff617 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Feature] More serialization support (#1295)



* fix script

* t

* fix weird bugs

* fix

* fix

* upload

* fix

* fix

* lint

* fix

* tmp

* fix serialization

* fix

* fix lint

* fix message

* fix

* lint

* address comment

* fix

* lint

* fix

* fix

* fix

* Remove duplicate serialization for meta graph
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 09ade2f2
......@@ -293,6 +293,9 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
* that have the same row, col indices. It corresponds to multigraph in
* graph terminology.
*/
constexpr uint64_t kDGLSerialize_AtenCsrMatrixMagic = 0xDD6cd31205dff127;
struct CSRMatrix {
/*! \brief the dense shape of the matrix */
int64_t num_rows = 0, num_cols = 0;
......@@ -305,24 +308,52 @@ struct CSRMatrix {
/*! \brief default constructor */
CSRMatrix() = default;
/*! \brief constructor */
CSRMatrix(int64_t nrows, int64_t ncols,
IdArray parr, IdArray iarr, IdArray darr = NullArray(),
bool sorted_flag = false)
: num_rows(nrows), num_cols(ncols), indptr(parr), indices(iarr),
data(darr), sorted(sorted_flag) {}
CSRMatrix(int64_t nrows, int64_t ncols, IdArray parr, IdArray iarr,
IdArray darr = NullArray(), bool sorted_flag = false)
: num_rows(nrows),
num_cols(ncols),
indptr(parr),
indices(iarr),
data(darr),
sorted(sorted_flag) {}
/*! \brief constructor from SparseMatrix object */
explicit CSRMatrix(const SparseMatrix& spmat)
: num_rows(spmat.num_rows), num_cols(spmat.num_cols),
indptr(spmat.indices[0]), indices(spmat.indices[1]), data(spmat.indices[2]),
: num_rows(spmat.num_rows),
num_cols(spmat.num_cols),
indptr(spmat.indices[0]),
indices(spmat.indices[1]),
data(spmat.indices[2]),
sorted(spmat.flags[0]) {}
// Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const {
return SparseMatrix(static_cast<int32_t>(SparseFormat::CSR),
num_rows, num_cols,
{indptr, indices, data},
{sorted});
return SparseMatrix(static_cast<int32_t>(SparseFormat::CSR), num_rows,
num_cols, {indptr, indices, data}, {sorted});
}
/*!
 * \brief Deserialize this CSR matrix from a dmlc stream.
 *
 * Fields are read in exactly the order Save() writes them: magic number,
 * num_cols, num_rows, indptr, indices, data, sorted. Note that num_cols
 * precedes num_rows in the stream.
 *
 * \param fs Stream to read from.
 * \return true; this is the only value ever returned. Read failures and a
 *         magic-number mismatch are reported through CHECK instead of the
 *         return value.
 */
bool Load(dmlc::Stream* fs) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
// Reject streams that were not produced by CSRMatrix::Save.
CHECK_EQ(magicNum, kDGLSerialize_AtenCsrMatrixMagic)
<< "Invalid CSRMatrix Data";
// Column count comes first on the wire, then row count.
CHECK(fs->Read(&num_cols)) << "Invalid num_cols";
CHECK(fs->Read(&num_rows)) << "Invalid num_rows";
CHECK(fs->Read(&indptr)) << "Invalid indptr";
CHECK(fs->Read(&indices)) << "Invalid indices";
CHECK(fs->Read(&data)) << "Invalid data";
CHECK(fs->Read(&sorted)) << "Invalid sorted";
return true;
}
/*!
 * \brief Serialize this CSR matrix to a dmlc stream.
 *
 * Writes the magic number followed by num_cols, num_rows, indptr, indices,
 * data and sorted. Load() reads the fields back in this same order, so the
 * two must be changed together.
 *
 * \param fs Stream to write to.
 */
void Save(dmlc::Stream* fs) const {
// The magic number lets Load() detect data that is not a CSRMatrix.
fs->Write(kDGLSerialize_AtenCsrMatrixMagic);
// num_cols is written before num_rows.
fs->Write(num_cols);
fs->Write(num_rows);
fs->Write(indptr);
fs->Write(indices);
fs->Write(data);
fs->Write(sorted);
}
};
......@@ -335,6 +366,9 @@ struct CSRMatrix {
* that have the same row, col indices. It corresponds to multigraph in
* graph terminology.
*/
constexpr uint64_t kDGLSerialize_AtenCooMatrixMagic = 0xDD61ffd305dff127;
// TODO(BarclayII): Graph queries on COO formats should support the case where
// data ordered by rows/columns instead of EID.
struct COOMatrix {
......@@ -351,24 +385,57 @@ struct COOMatrix {
/*! \brief default constructor */
COOMatrix() = default;
/*! \brief constructor */
COOMatrix(int64_t nrows, int64_t ncols,
IdArray rarr, IdArray carr, IdArray darr = NullArray(),
bool rsorted = false, bool csorted = false)
: num_rows(nrows), num_cols(ncols), row(rarr), col(carr), data(darr),
row_sorted(rsorted), col_sorted(csorted) {}
COOMatrix(int64_t nrows, int64_t ncols, IdArray rarr, IdArray carr,
IdArray darr = NullArray(), bool rsorted = false,
bool csorted = false)
: num_rows(nrows),
num_cols(ncols),
row(rarr),
col(carr),
data(darr),
row_sorted(rsorted),
col_sorted(csorted) {}
/*! \brief constructor from SparseMatrix object */
explicit COOMatrix(const SparseMatrix& spmat)
: num_rows(spmat.num_rows), num_cols(spmat.num_cols),
row(spmat.indices[0]), col(spmat.indices[1]), data(spmat.indices[2]),
row_sorted(spmat.flags[0]), col_sorted(spmat.flags[1]) {}
: num_rows(spmat.num_rows),
num_cols(spmat.num_cols),
row(spmat.indices[0]),
col(spmat.indices[1]),
data(spmat.indices[2]),
row_sorted(spmat.flags[0]),
col_sorted(spmat.flags[1]) {}
// Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const {
return SparseMatrix(static_cast<int32_t>(SparseFormat::COO),
num_rows, num_cols,
{row, col, data},
{row_sorted, col_sorted});
return SparseMatrix(static_cast<int32_t>(SparseFormat::COO), num_rows,
num_cols, {row, col, data}, {row_sorted, col_sorted});
}
/*!
 * \brief Deserialize this COO matrix from a dmlc stream.
 *
 * Fields are read in exactly the order Save() writes them: magic number,
 * num_cols, num_rows, row, col, data, row_sorted, col_sorted. Note that
 * num_cols precedes num_rows in the stream.
 *
 * \param fs Stream to read from.
 * \return true; this is the only value ever returned. Read failures and a
 *         magic-number mismatch are reported through CHECK instead of the
 *         return value.
 */
bool Load(dmlc::Stream* fs) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
// Reject streams that were not produced by COOMatrix::Save.
CHECK_EQ(magicNum, kDGLSerialize_AtenCooMatrixMagic)
<< "Invalid COOMatrix Data";
// Column count comes first on the wire, then row count.
CHECK(fs->Read(&num_cols)) << "Invalid num_cols";
CHECK(fs->Read(&num_rows)) << "Invalid num_rows";
CHECK(fs->Read(&row)) << "Invalid row";
CHECK(fs->Read(&col)) << "Invalid col";
CHECK(fs->Read(&data)) << "Invalid data";
CHECK(fs->Read(&row_sorted)) << "Invalid row_sorted";
CHECK(fs->Read(&col_sorted)) << "Invalid col_sorted";
return true;
}
/*!
 * \brief Serialize this COO matrix to a dmlc stream.
 *
 * Writes the magic number followed by num_cols, num_rows, row, col, data,
 * row_sorted and col_sorted. Load() reads the fields back in this same
 * order, so the two must be changed together.
 *
 * \param fs Stream to write to.
 */
void Save(dmlc::Stream* fs) const {
// The magic number lets Load() detect data that is not a COOMatrix.
fs->Write(kDGLSerialize_AtenCooMatrixMagic);
// num_cols is written before num_rows.
fs->Write(num_cols);
fs->Write(num_rows);
fs->Write(row);
fs->Write(col);
fs->Write(data);
fs->Write(row_sorted);
fs->Write(col_sorted);
}
};
......@@ -915,39 +982,8 @@ IdArray VecToIdArray(const std::vector<T>& vec,
} // namespace dgl
namespace dmlc {
namespace serializer {
using dgl::aten::CSRMatrix;
constexpr uint64_t kDGLSerialize_AtenCsrMatrixMagic = 0xDD6cd31205dff127;
template <>
struct Handler<CSRMatrix> {
inline static void Write(Stream* fs, const CSRMatrix& csr) {
fs->Write(kDGLSerialize_AtenCsrMatrixMagic);
fs->Write(csr.num_cols);
fs->Write(csr.num_rows);
fs->Write(csr.indptr);
fs->Write(csr.indices);
fs->Write(csr.data);
fs->Write(csr.sorted);
}
inline static bool Read(Stream* fs, CSRMatrix* csr) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_AtenCsrMatrixMagic)
<< "Invalid CSRMatrix Data";
CHECK(fs->Read(&csr->num_cols)) << "Invalid num_cols";
CHECK(fs->Read(&csr->num_rows)) << "Invalid num_rows";
CHECK(fs->Read(&csr->indptr)) << "Invalid indptr";
CHECK(fs->Read(&csr->indices)) << "Invalid indices";
CHECK(fs->Read(&csr->data)) << "Invalid data";
CHECK(fs->Read(&csr->sorted)) << "Invalid sorted";
return true;
}
};
} // namespace serializer
DMLC_DECLARE_TRAITS(has_saveload, dgl::aten::CSRMatrix, true);
DMLC_DECLARE_TRAITS(has_saveload, dgl::aten::COOMatrix, true);
} // namespace dmlc
#endif // DGL_ARRAY_H_
......@@ -438,6 +438,9 @@ class BaseHeteroGraph : public runtime::Object {
protected:
/*! \brief meta graph */
GraphPtr meta_graph_;
// Empty constructor: does not set meta_graph_, which is left
// default-constructed. HeteroGraph's serializer-only constructor delegates
// here; presumably Load() assigns meta_graph_ afterwards -- confirm in each
// subclass.
BaseHeteroGraph(){}
};
// Define HeteroGraphRef
......
......@@ -4,22 +4,25 @@
* \brief DGL serializer APIs
*/
#pragma once
#include <dgl/immutable_graph.h>
#include "heterograph.h"
#include "unit_graph.h"
#ifndef DGL_GRAPH_SERIALIZER_H_
#define DGL_GRAPH_SERIALIZER_H_
#include <memory>
namespace dgl {
// Util class to call the private/public empty constructor, which is needed for serialization
class Serializer {
public:
static HeteroGraph* EmptyHeteroGraph() { return new HeteroGraph(); }
static ImmutableGraph* EmptyImmutableGraph() {
return new ImmutableGraph(static_cast<COOPtr>(nullptr));
/*!
 * \brief Heap-allocate a value-initialized T and hand back the raw pointer.
 *
 * Exists so that code which befriends Serializer (rather than the caller)
 * can have its non-public default constructor invoked, e.g. by the smart
 * pointer deserialization handlers. The caller takes ownership of the
 * returned pointer.
 *
 * \tparam T type to construct; must be default constructible from here.
 * \return pointer to a newly allocated T.
 */
template <typename T>
static T* new_object() {
  T* instance = new T();
  return instance;
}
static UnitGraph* EmptyUnitGraph() {
return UnitGraph::EmptyGraph();
/*!
 * \brief Create a std::shared_ptr owning a value-initialized T.
 *
 * The object is built with a plain `new T()` issued from inside Serializer
 * instead of std::make_shared, so that types which keep their default
 * constructor non-public and befriend Serializer can still be constructed.
 *
 * \tparam T type to construct; must be default constructible from here.
 * \return shared pointer that is the sole owner of the new object.
 */
template <typename T>
static std::shared_ptr<T> make_shared() {
  std::shared_ptr<T> owner;
  owner.reset(new T());
  return owner;
}
};
} // namespace dgl
#endif // DGL_GRAPH_SERIALIZER_H_
......@@ -240,6 +240,12 @@ class CSR : public GraphInterface {
IdArray edge_ids() const { return adj_.data; }
/*! \return Load CSR from stream */
bool Load(dmlc::Stream *fs);
/*! \return Save CSR to stream */
void Save(dmlc::Stream* fs) const;
void SortCSR() override {
if (adj_.sorted)
return;
......@@ -247,9 +253,10 @@ class CSR : public GraphInterface {
}
private:
/*! \brief prive default constructor */
CSR() {adj_.sorted = false;}
friend class Serializer;
/*! \brief private default constructor */
CSR() {adj_.sorted = false;}
// The internal CSR adjacency matrix.
// The data field stores edge ids.
aten::CSRMatrix adj_;
......@@ -957,10 +964,10 @@ class ImmutableGraph: public GraphInterface {
*/
ImmutableGraphPtr Reverse() const;
/*! \return Load HeteroGraph from stream, using CSRMatrix*/
bool Load(dmlc::Stream* fs);
/*! \return Load ImmutableGraph from stream, using out csr */
bool Load(dmlc::Stream *fs);
/*! \return Save HeteroGraph to stream, using CSRMatrix */
/*! \return Save ImmutableGraph to stream, using out csr */
void Save(dmlc::Stream* fs) const;
void SortCSR() {
......@@ -969,6 +976,8 @@ class ImmutableGraph: public GraphInterface {
}
protected:
friend class Serializer;
/*!
 * \brief internal default constructor
 *
 * Leaves the graph without any stored format. Serializer is declared a
 * friend of this class so that it can call this constructor before Load().
 */
ImmutableGraph() {}
......@@ -1034,6 +1043,7 @@ CSR::CSR(int64_t num_vertices, int64_t num_edges,
} // namespace dgl
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::CSR, true);
DMLC_DECLARE_TRAITS(has_saveload, dgl::ImmutableGraph, true);
} // namespace dmlc
......
......@@ -10,6 +10,7 @@
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include "c_runtime_api.h"
#include "smart_ptr_serializer.h"
#include "ndarray.h"
namespace dmlc {
......
......@@ -7,6 +7,7 @@
#ifndef DGL_RUNTIME_SMART_PTR_SERIALIZER_H_
#define DGL_RUNTIME_SMART_PTR_SERIALIZER_H_
#include <dgl/graph_serializer.h>
#include <dmlc/io.h>
#include <dmlc/serializer.h>
......@@ -24,7 +25,7 @@ struct Handler<std::shared_ptr<T>> {
// shared_ptr<T>(), which is holding a nullptr. Here we need to manually
// reset to a real object for further loading
if (!(*data)) {
data->reset(new T());
data->reset(dgl::Serializer::new_object<T>());
}
return Handler<T>::Read(strm, data->get());
}
......@@ -40,7 +41,7 @@ struct Handler<std::unique_ptr<T>> {
// unique_ptr<T>(), which is holding a nullptr. Here we need to manually
// reset to a real object for further loading
if (!(*data)) {
data->reset(new T());
data->reset(dgl::Serializer::new_object<T>());
}
return Handler<T>::Read(strm, data->get());
}
......
......@@ -4,14 +4,12 @@
* \brief Heterograph implementation
*/
#include "./heterograph.h"
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <dgl/array.h>
#include <dgl/immutable_graph.h>
#include <dgl/graph_serializer.h>
#include <vector>
#include <tuple>
#include <utility>
#include "graph_serializer.h"
using namespace dgl::runtime;
......@@ -116,7 +114,7 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size());
CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed.";
// all relation graphs must have only one edge type
for (const auto rg : rel_graphs) {
for (const auto &rg : rel_graphs) {
CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must have only one edge type.";
}
// create num verts per type
......@@ -166,7 +164,7 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
bool HeteroGraph::IsMultigraph() const {
return const_cast<HeteroGraph*>(this)->is_multigraph_.Get([this] () {
for (const auto hg : relation_graphs_) {
for (const auto &hg : relation_graphs_) {
if (hg->IsMultigraph()) {
return true;
}
......@@ -316,31 +314,20 @@ bool HeteroGraph::Load(dmlc::Stream* fs) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_HeteroGraph) << "Invalid HeteroGraph Data";
auto meta_grptr = new ImmutableGraph(static_cast<COOPtr>(nullptr));
CHECK(fs->Read(meta_grptr)) << "Invalid Immutable Graph Data";
uint64_t num_relation_graphs;
CHECK(fs->Read(&num_relation_graphs)) << "Invalid num of relation graphs";
std::vector<HeteroGraphPtr> relgraphs;
for (size_t i = 0; i < num_relation_graphs; ++i) {
UnitGraph* ugptr = Serializer::EmptyUnitGraph();
CHECK(fs->Read(ugptr)) << "Invalid UnitGraph Data";
relgraphs.emplace_back(dynamic_cast<BaseHeteroGraph*>(ugptr));
}
HeteroGraph* hgptr = new HeteroGraph(GraphPtr(meta_grptr), relgraphs);
*this = *hgptr;
auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
meta_graph_ = meta_imgraph;
CHECK(fs->Read(&relation_graphs_)) << "Invalid relation_graphs_";
CHECK(fs->Read(&num_verts_per_type_)) << "Invalid num_verts_per_type_";
return true;
}
void HeteroGraph::Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_HeteroGraph);
auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
ImmutableGraph* meta_rptr = meta_graph_ptr.get();
fs->Write(*meta_rptr);
fs->Write(static_cast<uint64_t>(relation_graphs_.size()));
for (auto hptr : relation_graphs_) {
auto rptr = dynamic_cast<UnitGraph*>(hptr.get());
fs->Write(*rptr);
}
fs->Write(meta_graph_ptr);
fs->Write(relation_graphs_);
fs->Write(num_verts_per_type_);
}
} // namespace dgl
......@@ -198,7 +198,7 @@ class HeteroGraph : public BaseHeteroGraph {
friend class Serializer;
// Empty Constructor, only for serializer
HeteroGraph() : BaseHeteroGraph(static_cast<GraphPtr>(nullptr)) {}
HeteroGraph() : BaseHeteroGraph() {}
/*! \brief A map from edge type to unit graph */
std::vector<UnitGraphPtr> relation_graphs_;
......
......@@ -4,8 +4,9 @@
* \brief DGL immutable graph index implementation
*/
#include <dgl/packed_func_ext.h>
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/smart_ptr_serializer.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <string.h>
......@@ -276,6 +277,15 @@ DGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const {
return DGLIdIters(eid_data + start, eid_data + end);
}
/*!
 * \brief Load the CSR adjacency matrix from a stream.
 *
 * \param fs Stream to read from.
 * \return the result reported by the stream read, so that callers which
 *         wrap the read in CHECK(...) can actually observe a failure.
 *         Previously the result was discarded and true was returned
 *         unconditionally.
 */
bool CSR::Load(dmlc::Stream *fs) {
  // Load() is a non-const member, so adj_ is already mutable here and no
  // const_cast is required.
  return fs->Read(&adj_);
}
/*!
 * \brief Save the CSR adjacency matrix to a stream.
 *
 * Only adj_ is written; CSR::Load reads back the same single object.
 *
 * \param fs Stream to write to.
 */
void CSR::Save(dmlc::Stream *fs) const {
  const aten::CSRMatrix &matrix = adj_;
  fs->Write(matrix);
}
//////////////////////////////////////////////////////////
//
// COO graph implementation
......@@ -643,19 +653,16 @@ bool ImmutableGraph::Load(dmlc::Stream *fs) {
uint64_t magicNum;
aten::CSRMatrix out_csr_matrix;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_ImGraph) << "Invalid ImmutableGraph Data";
CHECK(fs->Read(&out_csr_matrix)) << "Invalid csr matrix";
CSRPtr csr(new CSR(out_csr_matrix.indptr, out_csr_matrix.indices,
out_csr_matrix.data));
auto g = new ImmutableGraph(nullptr, csr);
*this = *g;
CHECK_EQ(magicNum, kDGLSerialize_ImGraph)
<< "Invalid ImmutableGraph Magic Number";
CHECK(fs->Read(&out_csr_)) << "Invalid csr matrix";
return true;
}
/*! \brief Save ImmutableGraph to stream, using the out CSR */
void ImmutableGraph::Save(dmlc::Stream *fs) const {
fs->Write(kDGLSerialize_ImGraph);
fs->Write(GetOutCSR()->ToCSRMatrix());
fs->Write(GetOutCSR());
}
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
......
......@@ -4,12 +4,12 @@
* \brief UnitGraph graph implementation
*/
#include <dgl/array.h>
#include <dgl/lazy.h>
#include <dgl/immutable_graph.h>
#include <dgl/base_heterograph.h>
#include <dgl/immutable_graph.h>
#include <dgl/lazy.h>
#include "./unit_graph.h"
#include "../c_api_common.h"
#include "./unit_graph.h"
namespace dgl {
......@@ -401,7 +401,24 @@ class UnitGraph::COO : public BaseHeteroGraph {
(NumVertices(SrcType()) > 1000000);
}
/*!
 * \brief Deserialize this COO unit graph from a dmlc stream.
 *
 * Reads the meta graph first and then the adjacency matrix, mirroring the
 * order used by Save().
 *
 * \param fs Stream to read from.
 * \return true; this is the only value ever returned. Read failures are
 *         reported through CHECK instead of the return value.
 */
bool Load(dmlc::Stream* fs) {
// Serializer is used to build the ImmutableGraph that the read fills in.
auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
meta_graph_ = meta_imgraph;
CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
return true;
}
/*!
 * \brief Serialize this COO unit graph to a dmlc stream.
 *
 * The meta graph is written first, after being passed through
 * ImmutableGraph::ToImmutable, followed by the adjacency matrix. Load()
 * reads the two pieces back in the same order.
 *
 * \param fs Stream to write to.
 */
void Save(dmlc::Stream* fs) const {
  const auto immutable_meta = ImmutableGraph::ToImmutable(meta_graph());
  fs->Write(immutable_meta);
  fs->Write(adj_);
}
private:
friend class Serializer;
COO() {}
/*! \brief internal adjacency matrix. Data array is empty */
aten::COOMatrix adj_;
......@@ -427,7 +444,6 @@ class UnitGraph::CSR : public BaseHeteroGraph {
CHECK_EQ(indices->shape[0], edge_ids->shape[0])
<< "indices and edge id arrays should have the same length";
adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
sorted_ = false;
}
CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
......@@ -439,12 +455,10 @@ class UnitGraph::CSR : public BaseHeteroGraph {
CHECK_EQ(indices->shape[0], edge_ids->shape[0])
<< "indices and edge id arrays should have the same length";
adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
sorted_ = false;
}
CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
: BaseHeteroGraph(metagraph), adj_(csr) {
sorted_ = false;
}
inline dgl_type_t SrcType() const {
......@@ -738,15 +752,30 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return adj_;
}
/*!
 * \brief Deserialize this CSR unit graph from a dmlc stream.
 *
 * Reads the meta graph first and then the adjacency matrix, mirroring the
 * order used by Save().
 *
 * \param fs Stream to read from.
 * \return true; this is the only value ever returned. Read failures are
 *         reported through CHECK instead of the return value.
 */
bool Load(dmlc::Stream* fs) {
// Serializer is used to build the ImmutableGraph that the read fills in.
auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
meta_graph_ = meta_imgraph;
CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
return true;
}
/*!
 * \brief Serialize this CSR unit graph to a dmlc stream.
 *
 * Writes the meta graph (as returned by ImmutableGraph::ToImmutable) and
 * then the adjacency matrix; Load() consumes them in that order.
 *
 * \param fs Stream to write to.
 */
void Save(dmlc::Stream* fs) const {
  const auto meta_as_immutable = ImmutableGraph::ToImmutable(meta_graph());
  fs->Write(meta_as_immutable);
  fs->Write(adj_);
}
private:
friend class Serializer;
CSR() {};
/*! \brief internal adjacency matrix. Data array stores edge ids */
aten::CSRMatrix adj_;
/*! \brief multi-graph flag */
Lazy<bool> is_multigraph_;
/*! \brief indicate that the edges are stored in the sorted order. */
bool sorted_;
};
//////////////////////////////////////////////////////////
......@@ -1219,46 +1248,58 @@ SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const {
return SparseFormat::COO;
}
UnitGraph* UnitGraph::EmptyGraph() {
auto src = NewIdArray(0);
auto dst = NewIdArray(0);
auto mg = CreateUnitGraphMetaGraph(1);
COOPtr coo(new COO(mg, 0, 0, src, dst));
return new UnitGraph(mg, nullptr, nullptr, coo);
}
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;
// Using OutCSR
bool UnitGraph::Load(dmlc::Stream* fs) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data";
uint64_t num_vtypes, num_src, num_dst;
CHECK(fs->Read(&num_vtypes)) << "Invalid num_vtypes";
CHECK(fs->Read(&num_src)) << "Invalid num_src";
CHECK(fs->Read(&num_dst)) << "Invalid num_dst";
aten::CSRMatrix csr_matrix;
CHECK(fs->Read(&csr_matrix)) << "Invalid csr_matrix";
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr csr(new CSR(mg, num_src, num_dst, csr_matrix.indptr,
csr_matrix.indices, csr_matrix.data));
*this = UnitGraph(mg, nullptr, csr, nullptr);
int64_t format_code;
CHECK(fs->Read(&format_code)) << "Invalid format";
restrict_format_ = static_cast<SparseFormat>(format_code);
switch (restrict_format_) {
case SparseFormat::COO:
fs->Read(&coo_);
break;
case SparseFormat::CSR:
fs->Read(&out_csr_);
break;
case SparseFormat::CSC:
fs->Read(&in_csr_);
break;
default:
LOG(FATAL) << "unsupported format code";
break;
}
meta_graph_ = GetAny()->meta_graph();
return true;
}
// Using Out CSR
void UnitGraph::Save(dmlc::Stream* fs) const {
// Following CreateFromCSR signature
aten::CSRMatrix csr_matrix = GetCSRMatrix(0);
uint64_t num_vtypes = NumVertexTypes();
uint64_t num_src = NumVertices(SrcType());
uint64_t num_dst = NumVertices(DstType());
fs->Write(kDGLSerialize_UnitGraphMagic);
fs->Write(num_vtypes);
fs->Write(num_src);
fs->Write(num_dst);
fs->Write(csr_matrix);
// Didn't write UnitGraph::meta_graph_, since it's included in the underlying
// sparse matrix
auto avail_fmt = SelectFormat(SparseFormat::ANY);
fs->Write(static_cast<int64_t>(avail_fmt));
switch (avail_fmt) {
case SparseFormat::COO:
fs->Write(GetCOO());
break;
case SparseFormat::CSR:
fs->Write(GetOutCSR());
break;
case SparseFormat::CSC:
fs->Write(GetInCSR());
break;
default:
LOG(FATAL) << "unsupported format code";
break;
}
}
} // namespace dgl
......@@ -215,6 +215,9 @@ class UnitGraph : public BaseHeteroGraph {
friend class Serializer;
friend class HeteroGraph;
// private empty constructor
UnitGraph() {}
/*!
* \brief constructor
* \param metagraph metagraph
......@@ -250,9 +253,6 @@ class UnitGraph : public BaseHeteroGraph {
/*! \return Whether the graph is hypersparse */
bool IsHypersparse() const;
// Empty Graph for Serializer Usage
static UnitGraph* EmptyGraph();
// Graph stored in different format. We use an on-demand strategy: the format is
// only materialized if the operation that suitable for it is invoked.
/*! \brief CSR graph that stores reverse edges */
......@@ -275,6 +275,8 @@ class UnitGraph : public BaseHeteroGraph {
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::UnitGraph, true);
DMLC_DECLARE_TRAITS(has_saveload, dgl::UnitGraph::CSR, true);
DMLC_DECLARE_TRAITS(has_saveload, dgl::UnitGraph::COO, true);
} // namespace dmlc
#endif // DGL_GRAPH_UNIT_GRAPH_H_
#include <dgl/graph_serializer.h>
#include <dgl/immutable_graph.h>
#include <dmlc/memory_io.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <iostream>
#include <vector>
#include "../../src/graph/graph_serializer.h"
#include "../../src/graph/heterograph.h"
#include "../../src/graph/unit_graph.h"
#include "./common.h"
......@@ -13,50 +13,65 @@ using namespace dgl;
using namespace dgl::aten;
using namespace dmlc;
TEST(Serialize, DISABLED_UnitGraph) {
TEST(Serialize, UnitGraph_COO) {
aten::CSRMatrix csr_matrix;
auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst);
UnitGraph* ug = dynamic_cast<UnitGraph*>(mg.get());
auto mg = std::dynamic_pointer_cast<UnitGraph>(
dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, dgl::SparseFormat::COO));
std::string blob;
dmlc::MemoryStringStream ifs(&blob);
static_cast<dmlc::Stream*>(&ifs)->Write<UnitGraph>(*ug);
static_cast<dmlc::Stream *>(&ifs)->Write(mg);
dmlc::MemoryStringStream ofs(&blob);
UnitGraph* ug2 = Serializer::EmptyUnitGraph();
static_cast<dmlc::Stream*>(&ofs)->Read(ug2);
auto ug2 = Serializer::make_shared<UnitGraph>();
static_cast<dmlc::Stream *>(&ofs)->Read(&ug2);
EXPECT_EQ(ug2->NumVertices(0), 9);
EXPECT_EQ(ug2->NumVertices(1), 8);
EXPECT_EQ(ug2->NumEdges(0), 4);
EXPECT_EQ(ug2->FindEdge(0, 1).first, 2);
EXPECT_EQ(ug2->FindEdge(0, 1).second, 6);
delete ug2;
}
TEST(Serialize, DISABLED_ImmutableGraph) {
TEST(Serialize, UnitGraph_CSR) {
aten::CSRMatrix csr_matrix;
auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto gptr = ImmutableGraph::CreateFromCOO(10, src, dst);
ImmutableGraph* rptr = gptr.get();
auto mg = std::dynamic_pointer_cast<UnitGraph>(
dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, dgl::SparseFormat::CSR));
std::string blob;
dmlc::MemoryStringStream ifs(&blob);
static_cast<dmlc::Stream *>(&ifs)->Write(mg);
dmlc::MemoryStringStream ofs(&blob);
auto ug2 = Serializer::make_shared<UnitGraph>();
static_cast<dmlc::Stream *>(&ofs)->Read(&ug2);
// Query operation is not supported on CSR, how to check it?
}
TEST(Serialize, ImmutableGraph) {
auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto gptr = ImmutableGraph::CreateFromCOO(10, src, dst);
std::string blob;
dmlc::MemoryStringStream ifs(&blob);
static_cast<dmlc::Stream*>(&ifs)->Write(*rptr);
static_cast<dmlc::Stream *>(&ifs)->Write(gptr);
dmlc::MemoryStringStream ofs(&blob);
ImmutableGraph* rptr_read = new ImmutableGraph(static_cast<COOPtr>(nullptr));
static_cast<dmlc::Stream*>(&ofs)->Read(rptr_read);
auto rptr_read = dgl::Serializer::make_shared<ImmutableGraph>();
static_cast<dmlc::Stream *>(&ofs)->Read(&rptr_read);
EXPECT_EQ(rptr_read->NumEdges(), 4);
EXPECT_EQ(rptr_read->NumVertices(), 10);
EXPECT_EQ(rptr_read->FindEdge(2).first, 5);
EXPECT_EQ(rptr_read->FindEdge(2).second, 2);
delete rptr_read;
}
TEST(Serialize, DISABLED_HeteroGraph) {
TEST(Serialize, HeteroGraph) {
auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto mg1 = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst);
......@@ -68,18 +83,16 @@ TEST(Serialize, DISABLED_HeteroGraph) {
relgraphs.push_back(mg2);
src = VecToIdArray<int64_t>({0, 0});
dst = VecToIdArray<int64_t>({1, 0});
auto meta_gptr = ImmutableGraph::CreateFromCOO(2, src, dst);
HeteroGraph* hrptr = new HeteroGraph(meta_gptr, relgraphs);
auto meta_gptr = ImmutableGraph::CreateFromCOO(3, src, dst);
auto hrptr = std::make_shared<HeteroGraph>(meta_gptr, relgraphs);
std::string blob;
dmlc::MemoryStringStream ifs(&blob);
static_cast<dmlc::Stream*>(&ifs)->Write(*hrptr);
static_cast<dmlc::Stream *>(&ifs)->Write(hrptr);
dmlc::MemoryStringStream ofs(&blob);
HeteroGraph* gptr = dgl::Serializer::EmptyHeteroGraph();
static_cast<dmlc::Stream*>(&ofs)->Read(gptr);
auto gptr = dgl::Serializer::make_shared<HeteroGraph>();
static_cast<dmlc::Stream *>(&ofs)->Read(&gptr);
EXPECT_EQ(gptr->NumVertices(0), 9);
EXPECT_EQ(gptr->NumVertices(1), 8);
delete hrptr;
delete gptr;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment