Commit f103bbf9 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Data format] Serialization for UnitGraph (#1242)



* graph format

* fix lint

* lint

* fix

* unit test

* lint

* add magic num

* move serialize out of struct

* lint
Co-authored-by: default avatarzhoujinjing09 <zhoujinjing09@users.noreply.github.com>
parent 3f0c1005
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
#define DGL_ARRAY_H_ #define DGL_ARRAY_H_
#include <dgl/runtime/ndarray.h> #include <dgl/runtime/ndarray.h>
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include <tuple> #include <tuple>
...@@ -419,6 +421,8 @@ IdArray VecToIdArray(const std::vector<T>& vec, ...@@ -419,6 +421,8 @@ IdArray VecToIdArray(const std::vector<T>& vec,
return ret.CopyTo(ctx); return ret.CopyTo(ctx);
} }
///////////////////////// Dispatchers ////////////////////////// ///////////////////////// Dispatchers //////////////////////////
/* /*
...@@ -597,4 +601,40 @@ IdArray VecToIdArray(const std::vector<T>& vec, ...@@ -597,4 +601,40 @@ IdArray VecToIdArray(const std::vector<T>& vec,
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
namespace dmlc {
namespace serializer {
using dgl::aten::CSRMatrix;
constexpr uint64_t kDGLSerialize_AtenCsrMatrixMagic = 0xDD6cd31205dff127;
template <>
struct Handler<CSRMatrix> {
inline static void Write(Stream* fs, const CSRMatrix& csr) {
fs->Write(kDGLSerialize_AtenCsrMatrixMagic);
fs->Write(csr.num_cols);
fs->Write(csr.num_rows);
fs->Write(csr.indptr);
fs->Write(csr.indices);
fs->Write(csr.data);
fs->Write(csr.sorted);
}
inline static bool Read(Stream* fs, CSRMatrix* csr) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_AtenCsrMatrixMagic)
<< "Invalid CSRMatrix Data";
CHECK(fs->Read(&csr->num_cols)) << "Invalid num_cols";
CHECK(fs->Read(&csr->num_rows)) << "Invalid num_rows";
CHECK(fs->Read(&csr->indptr)) << "Invalid indptr";
CHECK(fs->Read(&csr->indices)) << "Invalid indices";
CHECK(fs->Read(&csr->data)) << "Invalid data";
CHECK(fs->Read(&csr->sorted)) << "Invalid sorted";
return true;
}
};
} // namespace serializer
} // namespace dmlc
#endif // DGL_ARRAY_H_ #endif // DGL_ARRAY_H_
...@@ -467,4 +467,9 @@ inline bool NDArray::Load(dmlc::Stream* strm) { ...@@ -467,4 +467,9 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::runtime::NDArray, true);
} // namespace dmlc
#endif // DGL_RUNTIME_NDARRAY_H_ #endif // DGL_RUNTIME_NDARRAY_H_
...@@ -56,7 +56,6 @@ using dgl::serialize::GraphData; ...@@ -56,7 +56,6 @@ using dgl::serialize::GraphData;
using dgl::serialize::GraphDataObject; using dgl::serialize::GraphDataObject;
namespace dmlc { namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, NDArray, true);
DMLC_DECLARE_TRAITS(has_saveload, GraphDataObject, true); DMLC_DECLARE_TRAITS(has_saveload, GraphDataObject, true);
} }
......
...@@ -1152,4 +1152,39 @@ SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const { ...@@ -1152,4 +1152,39 @@ SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const {
return SparseFormat::COO; return SparseFormat::COO;
} }
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;
bool UnitGraph::Load(dmlc::Stream* fs) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data";
uint64_t num_vtypes, num_src, num_dst;
CHECK(fs->Read(&num_vtypes)) << "Invalid num_vtypes";
CHECK(fs->Read(&num_src)) << "Invalid num_src";
CHECK(fs->Read(&num_dst)) << "Invalid num_dst";
aten::CSRMatrix csr_matrix;
CHECK(fs->Read(&csr_matrix)) << "Invalid csr_matrix";
SparseFormat restrict_format;
CHECK(fs->Read(&restrict_format)) << "Invalid restrict_format";
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr csr(new CSR(mg, num_src, num_dst, csr_matrix.indptr, csr_matrix.indices, csr_matrix.data));
*this = UnitGraph(mg, csr, nullptr, nullptr);
return true;
}
void UnitGraph::Save(dmlc::Stream* fs) const {
// Following CreateFromCSR signature
aten::CSRMatrix csr_matrix = GetInCSRMatrix();
uint64_t num_vtypes = NumVertexTypes();
uint64_t num_src = NumVertices(SrcType());
uint64_t num_dst = NumVertices(DstType());
SparseFormat restrict_format = restrict_format_;
fs->Write(kDGLSerialize_UnitGraphMagic);
fs->Write(num_vtypes);
fs->Write(num_src);
fs->Write(num_dst);
fs->Write(csr_matrix);
fs->Write(restrict_format);
}
} // namespace dgl } // namespace dgl
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include <dgl/lazy.h> #include <dgl/lazy.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <utility> #include <utility>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -177,6 +179,12 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -177,6 +179,12 @@ class UnitGraph : public BaseHeteroGraph {
/*! \return Return the COO matrix form */ /*! \return Return the COO matrix form */
aten::COOMatrix GetCOOMatrix() const; aten::COOMatrix GetCOOMatrix() const;
/*! \return Load UnitGraph from stream, using CSRMatrix*/
bool Load(dmlc::Stream* fs);
/*! \return Save UnitGraph to stream, using CSRMatrix */
void Save(dmlc::Stream* fs) const;
private: private:
/*! /*!
* \brief constructor * \brief constructor
...@@ -231,4 +239,8 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -231,4 +239,8 @@ class UnitGraph : public BaseHeteroGraph {
}; // namespace dgl }; // namespace dgl
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::UnitGraph, true);
} // namespace dmlc
#endif // DGL_GRAPH_UNIT_GRAPH_H_ #endif // DGL_GRAPH_UNIT_GRAPH_H_
...@@ -335,27 +335,24 @@ void BinaryOpReduce( ...@@ -335,27 +335,24 @@ void BinaryOpReduce(
} }
} }
// Comes from DGLArgValue::AsObjectRef() that allows argvalue to be either a GraphRef
// or a HeteroGraphRef void csrwrapper_switch(DGLArgValue argval,
#define CSRWRAPPER_SWITCH(argvalue, wrapper, ...) do { \ std::function<void(const CSRWrapper&)> fn) {
DGLArgValue argval = (argvalue); \ DGL_CHECK_TYPE_CODE(argval.type_code(), kObjectHandle);
DGL_CHECK_TYPE_CODE(argval.type_code(), kObjectHandle); \ if (argval.IsObjectType<GraphRef>()) {
std::shared_ptr<Object>& sptr = \ GraphRef g = argval;
*argval.ptr<std::shared_ptr<Object>>(); \ auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
if (ObjectTypeChecker<GraphRef>::Check(sptr.get())) { \ CHECK_NOTNULL(igptr);
GraphRef g = argval; \ ImmutableGraphCSRWrapper wrapper(igptr.get());
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); \ fn(wrapper);
CHECK_NOTNULL(igptr); \ } else if (argval.IsObjectType<HeteroGraphRef>()) {
ImmutableGraphCSRWrapper wrapper(igptr.get()); \ HeteroGraphRef g = argval;
{__VA_ARGS__} \ auto bgptr = std::dynamic_pointer_cast<UnitGraph>(g.sptr());
} else if (ObjectTypeChecker<HeteroGraphRef>::Check(sptr.get())) { \ CHECK_NOTNULL(bgptr);
HeteroGraphRef g = argval; \ UnitGraphCSRWrapper wrapper(bgptr.get());
auto bgptr = std::dynamic_pointer_cast<UnitGraph>(g.sptr()); \ fn(wrapper);
CHECK_NOTNULL(bgptr); \ }
UnitGraphCSRWrapper wrapper(bgptr.get()); \ }
{__VA_ARGS__} \
} \
} while (0)
DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce") DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
...@@ -370,12 +367,14 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce") ...@@ -370,12 +367,14 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce")
NDArray rhs_mapping = args[9]; NDArray rhs_mapping = args[9];
NDArray out_mapping = args[10]; NDArray out_mapping = args[10];
CSRWRAPPER_SWITCH(args[2], wrapper, { auto f = [&reducer, &op, &lhs, &rhs, &lhs_data, &rhs_data, &out_data,
BinaryOpReduce(reducer, op, wrapper, &lhs_mapping, &rhs_mapping,
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs), &out_mapping](const CSRWrapper& wrapper) {
lhs_data, rhs_data, out_data, BinaryOpReduce(reducer, op, wrapper, static_cast<binary_op::Target>(lhs),
lhs_mapping, rhs_mapping, out_mapping); static_cast<binary_op::Target>(rhs), lhs_data, rhs_data,
}); out_data, lhs_mapping, rhs_mapping, out_mapping);
};
csrwrapper_switch(args[2], f);
}); });
void BackwardLhsBinaryOpReduce( void BackwardLhsBinaryOpReduce(
...@@ -443,14 +442,16 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce") ...@@ -443,14 +442,16 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce")
NDArray grad_out_data = args[11]; NDArray grad_out_data = args[11];
NDArray grad_lhs_data = args[12]; NDArray grad_lhs_data = args[12];
CSRWRAPPER_SWITCH(args[2], wrapper, { auto f = [&reducer, &op, &lhs, &rhs, &lhs_mapping, &rhs_mapping,
&out_mapping, &lhs_data, &rhs_data, &out_data, &grad_out_data,
&grad_lhs_data](const CSRWrapper& wrapper) {
BackwardLhsBinaryOpReduce( BackwardLhsBinaryOpReduce(
reducer, op, wrapper, reducer, op, wrapper, static_cast<binary_op::Target>(lhs),
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs), static_cast<binary_op::Target>(rhs), lhs_mapping, rhs_mapping,
lhs_mapping, rhs_mapping, out_mapping, out_mapping, lhs_data, rhs_data, out_data, grad_out_data,
lhs_data, rhs_data, out_data, grad_out_data,
grad_lhs_data); grad_lhs_data);
}); };
csrwrapper_switch(args[2], f);
}); });
void BackwardRhsBinaryOpReduce( void BackwardRhsBinaryOpReduce(
...@@ -517,14 +518,17 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce") ...@@ -517,14 +518,17 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce")
NDArray grad_out_data = args[11]; NDArray grad_out_data = args[11];
NDArray grad_rhs_data = args[12]; NDArray grad_rhs_data = args[12];
CSRWRAPPER_SWITCH(args[2], wrapper, { auto f = [&reducer, &op, &lhs, &rhs, &lhs_mapping, &rhs_mapping,
&out_mapping, &lhs_data, &rhs_data, out_data, &grad_out_data,
&grad_rhs_data](const CSRWrapper& wrapper) {
BackwardRhsBinaryOpReduce( BackwardRhsBinaryOpReduce(
reducer, op, wrapper, reducer, op, wrapper, static_cast<binary_op::Target>(lhs),
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs), static_cast<binary_op::Target>(rhs), lhs_mapping, rhs_mapping,
lhs_mapping, rhs_mapping, out_mapping, out_mapping, lhs_data, rhs_data, out_data, grad_out_data,
lhs_data, rhs_data, out_data, grad_out_data,
grad_rhs_data); grad_rhs_data);
}); };
csrwrapper_switch(args[2], f);
}); });
void CopyReduce( void CopyReduce(
...@@ -557,12 +561,13 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelCopyReduce") ...@@ -557,12 +561,13 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelCopyReduce")
NDArray in_mapping = args[5]; NDArray in_mapping = args[5];
NDArray out_mapping = args[6]; NDArray out_mapping = args[6];
CSRWRAPPER_SWITCH(args[1], wrapper, { auto f = [&reducer, &target, &in_data, &out_data, &in_mapping,
CopyReduce(reducer, wrapper, &out_mapping](const CSRWrapper& wrapper) {
static_cast<binary_op::Target>(target), CopyReduce(reducer, wrapper, static_cast<binary_op::Target>(target),
in_data, out_data, in_data, out_data, in_mapping, out_mapping);
in_mapping, out_mapping); };
});
csrwrapper_switch(args[1], f);
}); });
void BackwardCopyReduce( void BackwardCopyReduce(
...@@ -606,13 +611,14 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce") ...@@ -606,13 +611,14 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce")
NDArray in_mapping = args[7]; NDArray in_mapping = args[7];
NDArray out_mapping = args[8]; NDArray out_mapping = args[8];
CSRWRAPPER_SWITCH(args[1], wrapper, { auto f = [&reducer, &target, &in_mapping, &out_mapping, &in_data, &out_data,
&grad_out_data, &grad_in_data](const CSRWrapper& wrapper) {
BackwardCopyReduce( BackwardCopyReduce(
reducer, wrapper, static_cast<binary_op::Target>(target), reducer, wrapper, static_cast<binary_op::Target>(target), in_mapping,
in_mapping, out_mapping, out_mapping, in_data, out_data, grad_out_data, grad_in_data);
in_data, out_data, grad_out_data, };
grad_in_data);
}); csrwrapper_switch(args[1], f);
}); });
} // namespace kernel } // namespace kernel
......
#include <dmlc/memory_io.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <iostream>
#include <vector>
#include "../../src/graph/unit_graph.h"
#include "./common.h"
using namespace dgl;
using namespace dgl::aten;
using namespace dmlc;
TEST(Serialize, UnitGraph) {
aten::CSRMatrix csr_matrix;
auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst);
UnitGraph* ug = dynamic_cast<UnitGraph*>(mg.get());
std::string blob;
dmlc::MemoryStringStream ifs(&blob);
static_cast<dmlc::Stream*>(&ifs)->Write<UnitGraph>(*ug);
dmlc::MemoryStringStream ofs(&blob);
src = NewIdArray(0);
dst = NewIdArray(0);
auto mg2 = dgl::UnitGraph::CreateFromCOO(
1, 0, 0, src, dst); // Any way to construct Empty UnitGraph?
UnitGraph* ug2 = dynamic_cast<UnitGraph*>(mg2.get());
static_cast<dmlc::Stream*>(&ofs)->Read(ug2);
EXPECT_EQ(ug2->NumVertices(0), 8);
EXPECT_EQ(ug2->NumVertices(1), 9);
EXPECT_EQ(ug2->NumEdges(0), 4);
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment