Unverified Commit 0a4e8b32 authored by Qidong Su's avatar Qidong Su Committed by GitHub
Browse files

[Optimization] Optimize pickling of heterograph (#1570)



* update

* update

* update

* update

* update

* update

* backward compatibility

* update

* update

* update

* update

* update

* test

* test

* update

* update

* update

* update

* update

* update
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent dc86bd42
...@@ -739,6 +739,22 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes( ...@@ -739,6 +739,22 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
* This class can be used as arguments and return values of a C API. * This class can be used as arguments and return values of a C API.
*/ */
struct HeteroPickleStates : public runtime::Object { struct HeteroPickleStates : public runtime::Object {
/*! \brief version number */
int64_t version = 0;
/*! \brief Metainformation
*
* metagraph, number of nodes per type, format, flags
*/
std::string meta;
/*! \brief Arrays representing graph structure (coo or csr) */
std::vector<IdArray> arrays;
/* To support backward compatibility, we have to retain fields in the old
* version of HeteroPickleStates
*/
/*! \brief Metagraph(64bits ImmutableGraph) */ /*! \brief Metagraph(64bits ImmutableGraph) */
GraphPtr metagraph; GraphPtr metagraph;
...@@ -766,10 +782,18 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states); ...@@ -766,10 +782,18 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states);
/*! /*!
* \brief Get the pickling state of the relation graph structure in backend tensors. * \brief Get the pickling state of the relation graph structure in backend tensors.
* *
* \returnAdjacency matrices of all relation graphs in a list of arrays. * \return a HeteroPickleStates object
*/ */
HeteroPickleStates HeteroPickle(HeteroGraphPtr graph); HeteroPickleStates HeteroPickle(HeteroGraphPtr graph);
/*!
* \brief Old version of HeteroUnpickle, for backward compatibility
*
* \param states Pickle states
* \return A heterograph pointer
*/
HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states);
} // namespace dgl } // namespace dgl
#endif // DGL_BASE_HETEROGRAPH_H_ #endif // DGL_BASE_HETEROGRAPH_H_
...@@ -17,6 +17,7 @@ from .frame import Frame, FrameRef, frame_like ...@@ -17,6 +17,7 @@ from .frame import Frame, FrameRef, frame_like
from .view import HeteroNodeView, HeteroNodeDataView, HeteroEdgeView, HeteroEdgeDataView from .view import HeteroNodeView, HeteroNodeDataView, HeteroEdgeView, HeteroEdgeDataView
from .base import ALL, SLICE_FULL, NTYPE, NID, ETYPE, EID, is_all, DGLError, dgl_warning from .base import ALL, SLICE_FULL, NTYPE, NID, ETYPE, EID, is_all, DGLError, dgl_warning
from .udf import NodeBatch, EdgeBatch from .udf import NodeBatch, EdgeBatch
from ._ffi.function import _init_api
__all__ = ['DGLHeteroGraph', 'combine_names'] __all__ = ['DGLHeteroGraph', 'combine_names']
...@@ -221,7 +222,10 @@ class DGLHeteroGraph(object): ...@@ -221,7 +222,10 @@ class DGLHeteroGraph(object):
self._canonical_etypes = [(ntypes[0][0], etypes[0], ntypes[1][0])] self._canonical_etypes = [(ntypes[0][0], etypes[0], ntypes[1][0])]
else: else:
self._ntypes = ntypes self._ntypes = ntypes
src_dst_map = find_src_dst_ntypes(self._ntypes, self._graph.metagraph) if len(ntypes) == 1:
src_dst_map = None
else:
src_dst_map = find_src_dst_ntypes(self._ntypes, self._graph.metagraph)
self._is_unibipartite = (src_dst_map is not None) self._is_unibipartite = (src_dst_map is not None)
if self._is_unibipartite: if self._is_unibipartite:
self._srctypes_invmap, self._dsttypes_invmap = src_dst_map self._srctypes_invmap, self._dsttypes_invmap = src_dst_map
...@@ -232,8 +236,11 @@ class DGLHeteroGraph(object): ...@@ -232,8 +236,11 @@ class DGLHeteroGraph(object):
# Handle edge types # Handle edge types
self._etypes = etypes self._etypes = etypes
if self._canonical_etypes is None: if self._canonical_etypes is None:
self._canonical_etypes = make_canonical_etypes( if (len(etypes) == 1 and len(ntypes) == 1):
self._etypes, self._ntypes, self._graph.metagraph) self._canonical_etypes = [(ntypes[0], etypes[0], ntypes[0])]
else:
self._canonical_etypes = make_canonical_etypes(
self._etypes, self._ntypes, self._graph.metagraph)
# An internal map from etype to canonical etype tuple. # An internal map from etype to canonical etype tuple.
# If two etypes have the same name, an empty tuple is stored instead to indicate # If two etypes have the same name, an empty tuple is stored instead to indicate
...@@ -4475,7 +4482,9 @@ def make_canonical_etypes(etypes, ntypes, metagraph): ...@@ -4475,7 +4482,9 @@ def make_canonical_etypes(etypes, ntypes, metagraph):
raise DGLError('Length of nodes type list must match the number of ' raise DGLError('Length of nodes type list must match the number of '
'nodes in the metagraph. {} vs {}'.format( 'nodes in the metagraph. {} vs {}'.format(
len(ntypes), metagraph.number_of_nodes())) len(ntypes), metagraph.number_of_nodes()))
src, dst, eid = metagraph.edges() if (len(etypes) == 1 and len(ntypes) == 1):
return [(ntypes[0], etypes[0], ntypes[0])]
src, dst, eid = metagraph.edges(order="eid")
rst = [(ntypes[sid], etypes[eid], ntypes[did]) for sid, did, eid in zip(src, dst, eid)] rst = [(ntypes[sid], etypes[eid], ntypes[did]) for sid, did, eid in zip(src, dst, eid)]
return rst return rst
...@@ -4519,17 +4528,14 @@ def find_src_dst_ntypes(ntypes, metagraph): ...@@ -4519,17 +4528,14 @@ def find_src_dst_ntypes(ntypes, metagraph):
a dictionary from type name to type id. Return None if the graph is a dictionary from type name to type id. Return None if the graph is
not uni-bipartite. not uni-bipartite.
""" """
src, dst, _ = metagraph.edges() ret = _CAPI_DGLFindSrcDstNtypes(metagraph)
if set(src.tonumpy()).isdisjoint(set(dst.tonumpy())): if ret is None:
srctypes = {ntypes[tid] : tid for tid in src}
dsttypes = {ntypes[tid] : tid for tid in dst}
# handle isolated node types
for ntid, ntype in enumerate(ntypes):
if ntype not in srctypes and ntype not in dsttypes:
srctypes[ntype] = ntid
return srctypes, dsttypes
else:
return None return None
else:
src, dst = ret
srctypes = {ntypes[tid.data] : tid.data for tid in src}
dsttypes = {ntypes[tid.data] : tid.data for tid in dst}
return srctypes, dsttypes
def infer_ntype_from_dict(graph, etype_dict): def infer_ntype_from_dict(graph, etype_dict):
"""Infer node type from dictionary of edge type to values. """Infer node type from dictionary of edge type to values.
...@@ -4784,3 +4790,5 @@ def check_idtype_dict(graph_dtype, tensor_dict): ...@@ -4784,3 +4790,5 @@ def check_idtype_dict(graph_dtype, tensor_dict):
"""check whether the dtypes of tensors in dict are consistent with graph's dtype""" """check whether the dtypes of tensors in dict are consistent with graph's dtype"""
for _, v in tensor_dict.items(): for _, v in tensor_dict.items():
check_same_dtype(graph_dtype, v) check_same_dtype(graph_dtype, v)
_init_api("dgl.heterograph")
...@@ -1166,45 +1166,53 @@ class FlattenedHeteroGraph(ObjectBase): ...@@ -1166,45 +1166,53 @@ class FlattenedHeteroGraph(ObjectBase):
class HeteroPickleStates(ObjectBase): class HeteroPickleStates(ObjectBase):
"""Pickle states object class in C++ backend.""" """Pickle states object class in C++ backend."""
@property @property
def metagraph(self): def version(self):
"""Metagraph """Version number
Returns Returns
------- -------
GraphIndex int
Metagraph structure version number
""" """
return _CAPI_DGLHeteroPickleStatesGetMetagraph(self) return _CAPI_DGLHeteroPickleStatesGetVersion(self)
@property @property
def num_nodes_per_type(self): def meta(self):
"""Number of nodes per edge type """Meta info
Returns Returns
------- -------
Tensor bytearray
Array of number of nodes for each type Serialized meta info
""" """
return F.zerocopy_from_dgl_ndarray(_CAPI_DGLHeteroPickleStatesGetNumVertices(self)) return bytearray(_CAPI_DGLHeteroPickleStatesGetMeta(self))
@property @property
def adjs(self): def arrays(self):
"""Adjacency matrices of all the relation graphs """Arrays representing the graph structure (COO or CSR)
Returns Returns
------- -------
list of dgl.ndarray.SparseMatrix list of dgl.ndarray.NDArray
Adjacency matrices Arrays
""" """
return list(_CAPI_DGLHeteroPickleStatesGetAdjs(self)) num_arr = _CAPI_DGLHeteroPickleStatesGetArraysNum(self)
arr_func = _CAPI_DGLHeteroPickleStatesGetArrays(self)
return [arr_func(i) for i in range(num_arr)]
def __getstate__(self): def __getstate__(self):
return self.metagraph, self.num_nodes_per_type, self.adjs arrays = [F.zerocopy_from_dgl_ndarray(arr) for arr in self.arrays]
return self.version, self.meta, arrays
def __setstate__(self, state): def __setstate__(self, state):
metagraph, num_nodes_per_type, adjs = state if isinstance(state[0], int):
num_nodes_per_type = F.zerocopy_to_dgl_ndarray(num_nodes_per_type) _, meta, arrays = state
self.__init_handle_by_constructor__( arrays = [F.zerocopy_to_dgl_ndarray(arr) for arr in arrays]
_CAPI_DGLCreateHeteroPickleStates, metagraph, num_nodes_per_type, adjs) self.__init_handle_by_constructor__(
_CAPI_DGLCreateHeteroPickleStates, meta, arrays)
else:
metagraph, num_nodes_per_type, adjs = state
num_nodes_per_type = F.zerocopy_to_dgl_ndarray(num_nodes_per_type)
self.__init_handle_by_constructor__(
_CAPI_DGLCreateHeteroPickleStatesOld, metagraph, num_nodes_per_type, adjs)
_init_api("dgl.heterograph_index") _init_api("dgl.heterograph_index")
...@@ -531,4 +531,30 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLAsImmutableGraph") ...@@ -531,4 +531,30 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLAsImmutableGraph")
*rv = GraphRef(hg->AsImmutableGraph()); *rv = GraphRef(hg->AsImmutableGraph());
}); });
DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef metagraph = args[0];
std::set<int64_t> dst_set;
std::set<int64_t> src_set;
for (int64_t eid = 0; eid < metagraph->NumEdges(); ++eid) {
auto edge = metagraph->FindEdge(eid);
auto src = edge.first;
auto dst = edge.second;
dst_set.insert(dst);
if (dst_set.count(src))
return;
}
List<Value> srclist, dstlist;
List<List<Value>> ret_list;
for (auto dst : dst_set)
dstlist.push_back(Value(MakeValue(dst)));
for (int64_t nid = 0 ; nid < metagraph->NumVertices(); ++nid)
if (!dst_set.count(nid))
srclist.push_back(Value(MakeValue(nid)));
ret_list.push_back(srclist);
ret_list.push_back(dstlist);
*rv = ret_list;
});
} // namespace dgl } // namespace dgl
...@@ -5,6 +5,9 @@ ...@@ -5,6 +5,9 @@
*/ */
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include <dgl/immutable_graph.h>
#include <dgl/graph_serializer.h>
#include <dmlc/memory_io.h>
#include "./heterograph.h" #include "./heterograph.h"
#include "../c_api_common.h" #include "../c_api_common.h"
...@@ -14,20 +17,32 @@ namespace dgl { ...@@ -14,20 +17,32 @@ namespace dgl {
HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) { HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {
HeteroPickleStates states; HeteroPickleStates states;
states.metagraph = graph->meta_graph(); dmlc::MemoryStringStream ofs(&states.meta);
states.num_nodes_per_type = graph->NumVerticesPerType(); dmlc::Stream *strm = &ofs;
states.adjs.resize(graph->NumEdgeTypes()); strm->Write(ImmutableGraph::ToImmutable(graph->meta_graph()));
strm->Write(graph->NumVerticesPerType());
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) { for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
SparseFormat fmt = graph->SelectFormat(etype, SparseFormat::kAny); SparseFormat fmt = graph->SelectFormat(etype, SparseFormat::kAny);
states.adjs[etype] = std::make_shared<SparseMatrix>();
switch (fmt) { switch (fmt) {
case SparseFormat::kCOO: case SparseFormat::kCOO: {
*states.adjs[etype] = graph->GetCOOMatrix(etype).ToSparseMatrix(); strm->Write(SparseFormat::kCOO);
const auto &coo = graph->GetCOOMatrix(etype);
strm->Write(coo.row_sorted);
strm->Write(coo.col_sorted);
states.arrays.push_back(coo.row);
states.arrays.push_back(coo.col);
break; break;
}
case SparseFormat::kCSR: case SparseFormat::kCSR:
case SparseFormat::kCSC: case SparseFormat::kCSC: {
*states.adjs[etype] = graph->GetCSRMatrix(etype).ToSparseMatrix(); strm->Write(SparseFormat::kCSR);
const auto &csr = graph->GetCSRMatrix(etype);
strm->Write(csr.sorted);
states.arrays.push_back(csr.indptr);
states.arrays.push_back(csr.indices);
states.arrays.push_back(csr.data);
break; break;
}
default: default:
LOG(FATAL) << "Unsupported sparse format."; LOG(FATAL) << "Unsupported sparse format.";
} }
...@@ -36,6 +51,62 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) { ...@@ -36,6 +51,62 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {
} }
HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) { HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
char *buf = const_cast<char *>(states.meta.c_str()); // a readonly stream?
dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size());
dmlc::Stream *strm = &ifs;
auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
CHECK(strm->Read(&meta_imgraph)) << "Invalid meta graph";
GraphPtr metagraph = meta_imgraph;
std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
std::vector<int64_t> num_nodes_per_type;
CHECK(strm->Read(&num_nodes_per_type)) << "Invalid num_nodes_per_type";
auto array_itr = states.arrays.begin();
for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
const auto& pair = metagraph->FindEdge(etype);
const dgl_type_t srctype = pair.first;
const dgl_type_t dsttype = pair.second;
const int64_t num_vtypes = (srctype == dsttype)? 1 : 2;
int64_t num_src = num_nodes_per_type[srctype];
int64_t num_dst = num_nodes_per_type[dsttype];
SparseFormat fmt;
CHECK(strm->Read(&fmt)) << "Invalid SparseFormat";
HeteroGraphPtr relgraph;
switch (fmt) {
case SparseFormat::kCOO: {
CHECK_GE(states.arrays.end() - array_itr, 2);
const auto &row = *(array_itr++);
const auto &col = *(array_itr++);
bool rsorted;
bool csorted;
CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'";
CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'";
auto coo = aten::COOMatrix(num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
relgraph = CreateFromCOO(num_vtypes, coo);
break;
}
case SparseFormat::kCSR: {
CHECK_GE(states.arrays.end() - array_itr, 3);
const auto &indptr = *(array_itr++);
const auto &indices = *(array_itr++);
const auto &edge_id = *(array_itr++);
bool sorted;
CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
auto csr = aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);
relgraph = CreateFromCSR(num_vtypes, csr);
break;
}
case SparseFormat::kCSC:
default:
LOG(FATAL) << "Unsupported sparse format.";
}
relgraphs[etype] = relgraph;
}
return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
}
// For backward compatibility
HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states) {
const auto metagraph = states.metagraph; const auto metagraph = states.metagraph;
const auto &num_nodes_per_type = states.num_nodes_per_type; const auto &num_nodes_per_type = states.num_nodes_per_type;
CHECK_EQ(states.adjs.size(), metagraph->NumEdges()); CHECK_EQ(states.adjs.size(), metagraph->NumEdges());
...@@ -63,36 +134,44 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) { ...@@ -63,36 +134,44 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type); return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
} }
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMetagraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetVersion")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroPickleStatesRef st = args[0];
*rv = st->version;
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMeta")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroPickleStatesRef st = args[0]; HeteroPickleStatesRef st = args[0];
*rv = GraphRef(st->metagraph); DGLByteArray buf;
buf.data = st->meta.c_str();
buf.size = st->meta.size();
*rv = buf;
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetNumVertices") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetArrays")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroPickleStatesRef st = args[0]; HeteroPickleStatesRef st = args[0];
*rv = NDArray::FromVector(st->num_nodes_per_type); *rv = ConvertNDArrayVectorToPackedFunc(st->arrays);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetAdjs") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetArraysNum")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroPickleStatesRef st = args[0]; HeteroPickleStatesRef st = args[0];
std::vector<SparseMatrixRef> refs(st->adjs.begin(), st->adjs.end()); *rv = static_cast<int64_t>(st->arrays.size());
*rv = List<SparseMatrixRef>(refs);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStates") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStates")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef metagraph = args[0]; std::string meta = args[0];
IdArray num_nodes_per_type = args[1]; const List<Value> arrays = args[1];
List<SparseMatrixRef> adjs = args[2];
std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates ); std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates );
st->metagraph = metagraph.sptr(); st->version = 1;
st->num_nodes_per_type = num_nodes_per_type.ToVector<int64_t>(); st->meta = meta;
st->adjs.reserve(adjs.size()); st->arrays.reserve(arrays.size());
for (const auto& ref : adjs) for (const auto& ref : arrays) {
st->adjs.push_back(ref.sptr()); st->arrays.push_back(ref->data);
}
*rv = HeteroPickleStatesRef(st); *rv = HeteroPickleStatesRef(st);
}); });
...@@ -107,8 +186,32 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickle") ...@@ -107,8 +186,32 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickle")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroPickleStatesRef ref = args[0]; HeteroPickleStatesRef ref = args[0];
HeteroGraphPtr graph = HeteroUnpickle(*ref.sptr()); HeteroGraphPtr graph;
switch (ref->version) {
case 0:
graph = HeteroUnpickleOld(*ref.sptr());
break;
case 1:
graph = HeteroUnpickle(*ref.sptr());
break;
default:
LOG(FATAL) << "Version can only be 0 or 1.";
}
*rv = HeteroGraphRef(graph); *rv = HeteroGraphRef(graph);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStatesOld")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef metagraph = args[0];
IdArray num_nodes_per_type = args[1];
List<SparseMatrixRef> adjs = args[2];
std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates );
st->version = 0;
st->metagraph = metagraph.sptr();
st->num_nodes_per_type = num_nodes_per_type.ToVector<int64_t>();
st->adjs.reserve(adjs.size());
for (const auto& ref : adjs)
st->adjs.push_back(ref.sptr());
*rv = HeteroPickleStatesRef(st);
});
} // namespace dgl } // namespace dgl
...@@ -9,6 +9,7 @@ import backend as F ...@@ -9,6 +9,7 @@ import backend as F
import dgl.function as fn import dgl.function as fn
import pickle import pickle
import io import io
import unittest
def _assert_is_identical(g, g2): def _assert_is_identical(g, g2):
assert g.is_readonly == g2.is_readonly assert g.is_readonly == g2.is_readonly
...@@ -257,6 +258,26 @@ def test_pickling_heterograph(): ...@@ -257,6 +258,26 @@ def test_pickling_heterograph():
new_g = _reconstruct_pickle(g) new_g = _reconstruct_pickle(g)
_assert_is_identical_hetero(g, new_g) _assert_is_identical_hetero(g, new_g)
@unittest.skipIf(dgl.backend.backend_name != "pytorch", reason="Only test for pytorch format file")
def test_pickling_heterograph_index_compatibility():
plays_spmat = ssp.coo_matrix(([1, 1, 1, 1], ([0, 1, 2, 1], [0, 0, 1, 1])))
wishes_nx = nx.DiGraph()
wishes_nx.add_nodes_from(['u0', 'u1', 'u2'], bipartite=0)
wishes_nx.add_nodes_from(['g0', 'g1'], bipartite=1)
wishes_nx.add_edge('u0', 'g1', id=0)
wishes_nx.add_edge('u2', 'g0', id=1)
follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
plays_g = dgl.bipartite(plays_spmat, 'user', 'plays', 'game')
wishes_g = dgl.bipartite(wishes_nx, 'user', 'wishes', 'game')
develops_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
g = dgl.hetero_from_relations([follows_g, plays_g, wishes_g, develops_g])
with open("tests/compute/hetero_pickle_old.pkl", "rb") as f:
gi = pickle.load(f)
f.close()
new_g = dgl.DGLHeteroGraph(gi, g.ntypes, g.etypes)
_assert_is_identical_hetero(g, new_g)
if __name__ == '__main__': if __name__ == '__main__':
test_pickling_index() test_pickling_index()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment