Commit f103bbf9 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Data format] Serialization for UnitGraph (#1242)



* graph format

* fix lint

* lint

* fix

* unit test

* lint

* add magic num

* move serialize out of struct

* lint
Co-authored-by: default avatarzhoujinjing09 <zhoujinjing09@users.noreply.github.com>
parent 3f0c1005
......@@ -10,6 +10,8 @@
#define DGL_ARRAY_H_
#include <dgl/runtime/ndarray.h>
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include <algorithm>
#include <vector>
#include <tuple>
......@@ -419,6 +421,8 @@ IdArray VecToIdArray(const std::vector<T>& vec,
return ret.CopyTo(ctx);
}
///////////////////////// Dispatchers //////////////////////////
/*
......@@ -597,4 +601,40 @@ IdArray VecToIdArray(const std::vector<T>& vec,
} // namespace aten
} // namespace dgl
namespace dmlc {
namespace serializer {
using dgl::aten::CSRMatrix;
/*! \brief Magic number written ahead of every serialized CSRMatrix and verified when reading. */
constexpr uint64_t kDGLSerialize_AtenCsrMatrixMagic = 0xDD6cd31205dff127;
/*!
 * \brief dmlc serializer specialization for dgl::aten::CSRMatrix.
 *
 * Stream layout, in order: magic number, num_cols, num_rows, indptr,
 * indices, data, sorted. Write() and Read() must stay in sync on this order.
 */
template <>
struct Handler<CSRMatrix> {
/*!
 * \brief Write the magic number followed by every field of the matrix.
 * \param fs The output stream.
 * \param csr The matrix to serialize.
 */
inline static void Write(Stream* fs, const CSRMatrix& csr) {
fs->Write(kDGLSerialize_AtenCsrMatrixMagic);
// num_cols is stored before num_rows; Read() consumes them in the same order.
fs->Write(csr.num_cols);
fs->Write(csr.num_rows);
fs->Write(csr.indptr);
fs->Write(csr.indices);
fs->Write(csr.data);
fs->Write(csr.sorted);
}
/*!
 * \brief Read a matrix previously produced by Write().
 * \param fs The input stream.
 * \param csr The matrix whose fields are overwritten.
 * \return Always true; a short read or a magic number mismatch trips a CHECK
 *         instead of returning false.
 */
inline static bool Read(Stream* fs, CSRMatrix* csr) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
// Reject data that was not written by this handler.
CHECK_EQ(magicNum, kDGLSerialize_AtenCsrMatrixMagic)
<< "Invalid CSRMatrix Data";
CHECK(fs->Read(&csr->num_cols)) << "Invalid num_cols";
CHECK(fs->Read(&csr->num_rows)) << "Invalid num_rows";
CHECK(fs->Read(&csr->indptr)) << "Invalid indptr";
CHECK(fs->Read(&csr->indices)) << "Invalid indices";
CHECK(fs->Read(&csr->data)) << "Invalid data";
CHECK(fs->Read(&csr->sorted)) << "Invalid sorted";
return true;
}
};
} // namespace serializer
} // namespace dmlc
#endif // DGL_ARRAY_H_
......@@ -467,4 +467,9 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
} // namespace runtime
} // namespace dgl
namespace dmlc {
// Sets the dmlc `has_saveload` trait to true for dgl::runtime::NDArray.
// Presumably this lets dmlc::Stream serialize NDArray through its Save/Load
// member functions -- confirm against dmlc/type_traits.h.
DMLC_DECLARE_TRAITS(has_saveload, dgl::runtime::NDArray, true);
} // namespace dmlc
#endif // DGL_RUNTIME_NDARRAY_H_
......@@ -56,7 +56,6 @@ using dgl::serialize::GraphData;
using dgl::serialize::GraphDataObject;
namespace dmlc {
// Sets the dmlc `has_saveload` trait to true for the types below.
// Presumably this lets dmlc::Stream serialize them through their Save/Load
// member functions -- confirm against dmlc/type_traits.h.
DMLC_DECLARE_TRAITS(has_saveload, NDArray, true);
DMLC_DECLARE_TRAITS(has_saveload, GraphDataObject, true);
}
......
......@@ -1152,4 +1152,39 @@ SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const {
return SparseFormat::COO;
}
/*! \brief Magic number written ahead of every serialized UnitGraph and verified when loading. */
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;

/*!
 * \brief Replace the content of this graph with the one stored in the stream.
 *
 * Expected layout, in order: magic number, number of vertex types, number of
 * source vertices, number of destination vertices, CSR matrix, restrict
 * format.
 *
 * \param fs The input stream.
 * \return Always true; a short read or a magic number mismatch trips a CHECK
 *         instead of returning false.
 */
bool UnitGraph::Load(dmlc::Stream* fs) {
  uint64_t magicNum;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data";
  uint64_t num_vtypes, num_src, num_dst;
  CHECK(fs->Read(&num_vtypes)) << "Invalid num_vtypes";
  CHECK(fs->Read(&num_src)) << "Invalid num_src";
  CHECK(fs->Read(&num_dst)) << "Invalid num_dst";
  aten::CSRMatrix csr_matrix;
  CHECK(fs->Read(&csr_matrix)) << "Invalid csr_matrix";
  SparseFormat restrict_format;
  CHECK(fs->Read(&restrict_format)) << "Invalid restrict_format";
  auto mg = CreateUnitGraphMetaGraph(num_vtypes);
  // NOTE(review): the CSR is built with (num_src, num_dst) and handed over as
  // the first structure argument of the UnitGraph constructor; confirm this
  // matches the orientation of the matrix that Save() writes.
  CSRPtr csr(new CSR(mg, num_src, num_dst, csr_matrix.indptr,
                     csr_matrix.indices, csr_matrix.data));
  *this = UnitGraph(mg, csr, nullptr, nullptr);
  // The constructor call above is not given the stored restriction, so
  // restore it explicitly; otherwise it is lost on a save/load round trip.
  restrict_format_ = restrict_format;
  return true;
}
/*!
 * \brief Serialize this graph to the stream.
 *
 * Layout, in order: magic number, number of vertex types, number of source
 * vertices, number of destination vertices, the matrix returned by
 * GetInCSRMatrix(), restrict format.
 *
 * \param fs The output stream.
 */
void UnitGraph::Save(dmlc::Stream* fs) const {
  // All values are gathered before the first write to the stream.
  const aten::CSRMatrix in_csr = GetInCSRMatrix();
  const uint64_t vtype_count = NumVertexTypes();
  const uint64_t src_count = NumVertices(SrcType());
  const uint64_t dst_count = NumVertices(DstType());

  // Header first so that Load() can reject foreign data early.
  fs->Write(kDGLSerialize_UnitGraphMagic);
  // Sizes are widened to uint64_t to match what Load() reads back.
  fs->Write(vtype_count);
  fs->Write(src_count);
  fs->Write(dst_count);
  fs->Write(in_csr);
  fs->Write(restrict_format_);
}
} // namespace dgl
......@@ -10,6 +10,8 @@
#include <dgl/base_heterograph.h>
#include <dgl/lazy.h>
#include <dgl/array.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <utility>
#include <string>
#include <vector>
......@@ -177,6 +179,12 @@ class UnitGraph : public BaseHeteroGraph {
/*! \return Return the COO matrix form */
aten::COOMatrix GetCOOMatrix() const;
/*!
 * \brief Load UnitGraph from stream, using CSRMatrix
 * \param fs The stream to read from.
 * \return true when the graph has been loaded.
 */
bool Load(dmlc::Stream* fs);
/*!
 * \brief Save UnitGraph to stream, using CSRMatrix
 * \param fs The stream to write to.
 */
void Save(dmlc::Stream* fs) const;
private:
/*!
* \brief constructor
......@@ -231,4 +239,8 @@ class UnitGraph : public BaseHeteroGraph {
}; // namespace dgl
namespace dmlc {
// Sets the dmlc `has_saveload` trait to true for dgl::UnitGraph.
// Presumably this lets dmlc::Stream serialize UnitGraph through its Save/Load
// member functions -- confirm against dmlc/type_traits.h.
DMLC_DECLARE_TRAITS(has_saveload, dgl::UnitGraph, true);
} // namespace dmlc
#endif // DGL_GRAPH_UNIT_GRAPH_H_
......@@ -335,27 +335,24 @@ void BinaryOpReduce(
}
}
// Dispatches on the type of the graph object carried by `argvalue` and runs
// the statements passed as the trailing arguments with a CSR wrapper in scope.
// Comes from DGLArgValue::AsObjectRef() that allows argvalue to be either a GraphRef
// or a HeteroGraphRef
//   - GraphRef: the object is cast to ImmutableGraph and wrapped in an
//     ImmutableGraphCSRWrapper.
//   - HeteroGraphRef: the object is cast to UnitGraph and wrapped in a
//     UnitGraphCSRWrapper.
// `wrapper` is the name given to the local wrapper variable, so the statements
// can refer to it. A failed cast trips CHECK_NOTNULL; an object of any other
// type falls through without running the statements.
#define CSRWRAPPER_SWITCH(argvalue, wrapper, ...) do { \
DGLArgValue argval = (argvalue); \
DGL_CHECK_TYPE_CODE(argval.type_code(), kObjectHandle); \
std::shared_ptr<Object>& sptr = \
*argval.ptr<std::shared_ptr<Object>>(); \
if (ObjectTypeChecker<GraphRef>::Check(sptr.get())) { \
GraphRef g = argval; \
auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); \
CHECK_NOTNULL(igptr); \
ImmutableGraphCSRWrapper wrapper(igptr.get()); \
{__VA_ARGS__} \
} else if (ObjectTypeChecker<HeteroGraphRef>::Check(sptr.get())) { \
HeteroGraphRef g = argval; \
auto bgptr = std::dynamic_pointer_cast<UnitGraph>(g.sptr()); \
CHECK_NOTNULL(bgptr); \
UnitGraphCSRWrapper wrapper(bgptr.get()); \
{__VA_ARGS__} \
} \
} while (0)
/*!
 * \brief Run a callback on the CSR wrapper matching the given graph object.
 *
 * The argument must be an object handle holding either a GraphRef (wrapped in
 * an ImmutableGraphCSRWrapper) or a HeteroGraphRef (wrapped in a
 * UnitGraphCSRWrapper). The wrapper only lives for the duration of the call,
 * so the callback must not keep a reference to it.
 *
 * \param argval The packed-function argument carrying the graph object.
 * \param fn The callback invoked with the wrapper.
 */
void csrwrapper_switch(DGLArgValue argval,
                       std::function<void(const CSRWrapper&)> fn) {
  DGL_CHECK_TYPE_CODE(argval.type_code(), kObjectHandle);
  if (argval.IsObjectType<GraphRef>()) {
    GraphRef g = argval;
    // The homogeneous path only supports ImmutableGraph.
    auto igptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
    CHECK_NOTNULL(igptr);
    ImmutableGraphCSRWrapper wrapper(igptr.get());
    fn(wrapper);
  } else if (argval.IsObjectType<HeteroGraphRef>()) {
    HeteroGraphRef g = argval;
    // The heterogeneous path only supports UnitGraph.
    auto bgptr = std::dynamic_pointer_cast<UnitGraph>(g.sptr());
    CHECK_NOTNULL(bgptr);
    UnitGraphCSRWrapper wrapper(bgptr.get());
    fn(wrapper);
  } else {
    // Fail loudly: silently skipping the callback would leave the caller's
    // output untouched without any indication that nothing was computed.
    LOG(FATAL) << "Unsupported graph object: expected a GraphRef or a "
               << "HeteroGraphRef.";
  }
}
DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
......@@ -370,12 +367,14 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBinaryOpReduce")
NDArray rhs_mapping = args[9];
NDArray out_mapping = args[10];
CSRWRAPPER_SWITCH(args[2], wrapper, {
BinaryOpReduce(reducer, op, wrapper,
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
lhs_data, rhs_data, out_data,
lhs_mapping, rhs_mapping, out_mapping);
});
auto f = [&reducer, &op, &lhs, &rhs, &lhs_data, &rhs_data, &out_data,
&lhs_mapping, &rhs_mapping,
&out_mapping](const CSRWrapper& wrapper) {
BinaryOpReduce(reducer, op, wrapper, static_cast<binary_op::Target>(lhs),
static_cast<binary_op::Target>(rhs), lhs_data, rhs_data,
out_data, lhs_mapping, rhs_mapping, out_mapping);
};
csrwrapper_switch(args[2], f);
});
void BackwardLhsBinaryOpReduce(
......@@ -443,14 +442,16 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardLhsBinaryOpReduce")
NDArray grad_out_data = args[11];
NDArray grad_lhs_data = args[12];
CSRWRAPPER_SWITCH(args[2], wrapper, {
auto f = [&reducer, &op, &lhs, &rhs, &lhs_mapping, &rhs_mapping,
&out_mapping, &lhs_data, &rhs_data, &out_data, &grad_out_data,
&grad_lhs_data](const CSRWrapper& wrapper) {
BackwardLhsBinaryOpReduce(
reducer, op, wrapper,
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
lhs_mapping, rhs_mapping, out_mapping,
lhs_data, rhs_data, out_data, grad_out_data,
reducer, op, wrapper, static_cast<binary_op::Target>(lhs),
static_cast<binary_op::Target>(rhs), lhs_mapping, rhs_mapping,
out_mapping, lhs_data, rhs_data, out_data, grad_out_data,
grad_lhs_data);
});
};
csrwrapper_switch(args[2], f);
});
void BackwardRhsBinaryOpReduce(
......@@ -517,14 +518,17 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardRhsBinaryOpReduce")
NDArray grad_out_data = args[11];
NDArray grad_rhs_data = args[12];
CSRWRAPPER_SWITCH(args[2], wrapper, {
auto f = [&reducer, &op, &lhs, &rhs, &lhs_mapping, &rhs_mapping,
&out_mapping, &lhs_data, &rhs_data, out_data, &grad_out_data,
&grad_rhs_data](const CSRWrapper& wrapper) {
BackwardRhsBinaryOpReduce(
reducer, op, wrapper,
static_cast<binary_op::Target>(lhs), static_cast<binary_op::Target>(rhs),
lhs_mapping, rhs_mapping, out_mapping,
lhs_data, rhs_data, out_data, grad_out_data,
reducer, op, wrapper, static_cast<binary_op::Target>(lhs),
static_cast<binary_op::Target>(rhs), lhs_mapping, rhs_mapping,
out_mapping, lhs_data, rhs_data, out_data, grad_out_data,
grad_rhs_data);
});
};
csrwrapper_switch(args[2], f);
});
void CopyReduce(
......@@ -557,12 +561,13 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelCopyReduce")
NDArray in_mapping = args[5];
NDArray out_mapping = args[6];
CSRWRAPPER_SWITCH(args[1], wrapper, {
CopyReduce(reducer, wrapper,
static_cast<binary_op::Target>(target),
in_data, out_data,
in_mapping, out_mapping);
});
auto f = [&reducer, &target, &in_data, &out_data, &in_mapping,
&out_mapping](const CSRWrapper& wrapper) {
CopyReduce(reducer, wrapper, static_cast<binary_op::Target>(target),
in_data, out_data, in_mapping, out_mapping);
};
csrwrapper_switch(args[1], f);
});
void BackwardCopyReduce(
......@@ -606,13 +611,14 @@ DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce")
NDArray in_mapping = args[7];
NDArray out_mapping = args[8];
CSRWRAPPER_SWITCH(args[1], wrapper, {
auto f = [&reducer, &target, &in_mapping, &out_mapping, &in_data, &out_data,
&grad_out_data, &grad_in_data](const CSRWrapper& wrapper) {
BackwardCopyReduce(
reducer, wrapper, static_cast<binary_op::Target>(target),
in_mapping, out_mapping,
in_data, out_data, grad_out_data,
grad_in_data);
});
reducer, wrapper, static_cast<binary_op::Target>(target), in_mapping,
out_mapping, in_data, out_data, grad_out_data, grad_in_data);
};
csrwrapper_switch(args[1], f);
});
} // namespace kernel
......
#include <dmlc/memory_io.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <iostream>
#include <vector>
#include "../../src/graph/unit_graph.h"
#include "./common.h"
using namespace dgl;
using namespace dgl::aten;
using namespace dmlc;
// Round-trips a UnitGraph through an in-memory dmlc stream and checks the
// vertex and edge counts of the loaded graph.
TEST(Serialize, UnitGraph) {
  auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
  auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
  auto mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst);
  UnitGraph* ug = dynamic_cast<UnitGraph*>(mg.get());
  ASSERT_TRUE(ug != nullptr);

  // Serialize into a string-backed stream.
  std::string blob;
  dmlc::MemoryStringStream writer(&blob);
  static_cast<dmlc::Stream*>(&writer)->Write<UnitGraph>(*ug);

  // Deserialize from the same buffer into a placeholder graph.
  dmlc::MemoryStringStream reader(&blob);
  src = NewIdArray(0);
  dst = NewIdArray(0);
  auto mg2 = dgl::UnitGraph::CreateFromCOO(
      1, 0, 0, src, dst);  // Any way to construct Empty UnitGraph?
  UnitGraph* ug2 = dynamic_cast<UnitGraph*>(mg2.get());
  ASSERT_TRUE(ug2 != nullptr);
  ASSERT_TRUE(static_cast<dmlc::Stream*>(&reader)->Read(ug2));

  // NOTE(review): the graph above is created with the arguments (2, 9, 8),
  // presumably (num_vtypes, num_src, num_dst); confirm that the counts below
  // (8 for type 0, 9 for type 1) are the intended result of a round trip.
  EXPECT_EQ(ug2->NumVertices(0), 8);
  EXPECT_EQ(ug2->NumVertices(1), 9);
  EXPECT_EQ(ug2->NumEdges(0), 4);
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment