Unverified Commit 220a1e68 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[BUG] Fixing bug in pickling readonly graphs (#397)

* fixing bug in pickling readonly graphs

* multigraph test coverage

* fixing lint (??????)

* another lint

* is_readonly interface
parent 788d8dd4
...@@ -422,6 +422,12 @@ class DGLGraph(object): ...@@ -422,6 +422,12 @@ class DGLGraph(object):
""" """
return self._graph.is_multigraph() return self._graph.is_multigraph()
@property
def is_readonly(self):
"""True if the graph is readonly, False otherwise.
"""
return self._graph.is_readonly()
def number_of_edges(self): def number_of_edges(self):
"""Return the number of edges in the graph. """Return the number of edges in the graph.
......
...@@ -30,6 +30,7 @@ class GraphIndex(object): ...@@ -30,6 +30,7 @@ class GraphIndex(object):
def __del__(self): def __del__(self):
"""Free this graph index object.""" """Free this graph index object."""
if hasattr(self, '_handle'):
_CAPI_DGLGraphFree(self._handle) _CAPI_DGLGraphFree(self._handle)
def __getstate__(self): def __getstate__(self):
...@@ -46,19 +47,19 @@ class GraphIndex(object): ...@@ -46,19 +47,19 @@ class GraphIndex(object):
""" """
n_nodes, multigraph, readonly, src, dst = state n_nodes, multigraph, readonly, src, dst = state
if readonly: self._cache = {}
self._readonly = readonly
self._multigraph = multigraph self._multigraph = multigraph
self.init(src, dst, F.arange(0, len(src)), n_nodes) self._readonly = readonly
if readonly:
self._init(src, dst, utils.toindex(F.arange(0, len(src))), n_nodes)
else: else:
self._handle = _CAPI_DGLGraphCreateMutable(multigraph) self._handle = _CAPI_DGLGraphCreateMutable(multigraph)
self._cache = {}
self.clear() self.clear()
self.add_nodes(n_nodes) self.add_nodes(n_nodes)
self.add_edges(src, dst) self.add_edges(src, dst)
def init(self, src_ids, dst_ids, edge_ids, num_nodes): def _init(self, src_ids, dst_ids, edge_ids, num_nodes):
"""The actual init function""" """The actual init function"""
assert len(src_ids) == len(dst_ids) assert len(src_ids) == len(dst_ids)
assert len(src_ids) == len(edge_ids) assert len(src_ids) == len(edge_ids)
...@@ -746,7 +747,7 @@ class GraphIndex(object): ...@@ -746,7 +747,7 @@ class GraphIndex(object):
eid = utils.toindex(eid) eid = utils.toindex(eid)
src = utils.toindex(src) src = utils.toindex(src)
dst = utils.toindex(dst) dst = utils.toindex(dst)
self.init(src, dst, eid, num_nodes) self._init(src, dst, eid, num_nodes)
def from_scipy_sparse_matrix(self, adj): def from_scipy_sparse_matrix(self, adj):
...@@ -763,7 +764,7 @@ class GraphIndex(object): ...@@ -763,7 +764,7 @@ class GraphIndex(object):
src = utils.toindex(adj_coo.row) src = utils.toindex(adj_coo.row)
dst = utils.toindex(adj_coo.col) dst = utils.toindex(adj_coo.col)
edge_ids = utils.toindex(F.arange(0, len(adj_coo.row))) edge_ids = utils.toindex(F.arange(0, len(adj_coo.row)))
self.init(src, dst, edge_ids, num_nodes) self._init(src, dst, edge_ids, num_nodes)
def from_edge_list(self, elist): def from_edge_list(self, elist):
...@@ -786,7 +787,7 @@ class GraphIndex(object): ...@@ -786,7 +787,7 @@ class GraphIndex(object):
if min_nodes != 0: if min_nodes != 0:
raise DGLError('Invalid edge list. Nodes must start from 0.') raise DGLError('Invalid edge list. Nodes must start from 0.')
edge_ids = utils.toindex(F.arange(0, len(src))) edge_ids = utils.toindex(F.arange(0, len(src)))
self.init(src_ids, dst_ids, edge_ids, num_nodes) self._init(src_ids, dst_ids, edge_ids, num_nodes)
def line_graph(self, backtracking=True): def line_graph(self, backtracking=True):
"""Return the line graph of this graph. """Return the line graph of this graph.
......
...@@ -20,8 +20,10 @@ Graph::Graph(IdArray src_ids, IdArray dst_ids, IdArray edge_ids, size_t num_node ...@@ -20,8 +20,10 @@ Graph::Graph(IdArray src_ids, IdArray dst_ids, IdArray edge_ids, size_t num_node
CHECK(IsValidIdArray(edge_ids)); CHECK(IsValidIdArray(edge_ids));
this->AddVertices(num_nodes); this->AddVertices(num_nodes);
num_edges_ = src_ids->shape[0]; num_edges_ = src_ids->shape[0];
CHECK(static_cast<int64_t>(num_edges_) == dst_ids->shape[0]) << "vectors in COO must have the same length"; CHECK(static_cast<int64_t>(num_edges_) == dst_ids->shape[0])
CHECK(static_cast<int64_t>(num_edges_) == edge_ids->shape[0]) << "vectors in COO must have the same length"; << "vectors in COO must have the same length";
CHECK(static_cast<int64_t>(num_edges_) == edge_ids->shape[0])
<< "vectors in COO must have the same length";
const dgl_id_t *src_data = static_cast<dgl_id_t*>(src_ids->data); const dgl_id_t *src_data = static_cast<dgl_id_t*>(src_ids->data);
const dgl_id_t *dst_data = static_cast<dgl_id_t*>(dst_ids->data); const dgl_id_t *dst_data = static_cast<dgl_id_t*>(dst_ids->data);
const dgl_id_t *edge_data = static_cast<dgl_id_t*>(edge_ids->data); const dgl_id_t *edge_data = static_cast<dgl_id_t*>(edge_ids->data);
...@@ -507,7 +509,10 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const ...@@ -507,7 +509,10 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const
uint64_t num_edges = NumEdges(); uint64_t num_edges = NumEdges();
uint64_t num_nodes = NumVertices(); uint64_t num_nodes = NumVertices();
if (fmt == "coo") { if (fmt == "coo") {
IdArray idx = IdArray::Empty({2 * static_cast<int64_t>(num_edges)}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray idx = IdArray::Empty(
{2 * static_cast<int64_t>(num_edges)},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0});
int64_t *idx_data = static_cast<int64_t*>(idx->data); int64_t *idx_data = static_cast<int64_t*>(idx->data);
if (transpose) { if (transpose) {
std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data); std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data);
...@@ -516,17 +521,28 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const ...@@ -516,17 +521,28 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const
std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), idx_data); std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), idx_data);
std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data + num_edges); std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data + num_edges);
} }
IdArray eid = IdArray::Empty({static_cast<int64_t>(num_edges)}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray eid = IdArray::Empty(
{static_cast<int64_t>(num_edges)},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0});
int64_t *eid_data = static_cast<int64_t*>(eid->data); int64_t *eid_data = static_cast<int64_t*>(eid->data);
for (uint64_t eid = 0; eid < num_edges; ++eid) { for (uint64_t eid = 0; eid < num_edges; ++eid) {
eid_data[eid] = eid; eid_data[eid] = eid;
} }
return std::vector<IdArray>{idx, eid}; return std::vector<IdArray>{idx, eid};
} else if (fmt == "csr") { } else if (fmt == "csr") {
IdArray indptr = IdArray::Empty({static_cast<int64_t>(num_nodes) + 1}, DLDataType{kDLInt, 64, 1}, IdArray indptr = IdArray::Empty(
{static_cast<int64_t>(num_nodes) + 1},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0});
IdArray indices = IdArray::Empty(
{static_cast<int64_t>(num_edges)},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0});
IdArray eid = IdArray::Empty(
{static_cast<int64_t>(num_edges)},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0}); DLContext{kDLCPU, 0});
IdArray indices = IdArray::Empty({static_cast<int64_t>(num_edges)}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray eid = IdArray::Empty({static_cast<int64_t>(num_edges)}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t *indptr_data = static_cast<int64_t*>(indptr->data); int64_t *indptr_data = static_cast<int64_t*>(indptr->data);
int64_t *indices_data = static_cast<int64_t*>(indices->data); int64_t *indices_data = static_cast<int64_t*>(indices->data);
int64_t *eid_data = static_cast<int64_t*>(eid->data); int64_t *eid_data = static_cast<int64_t*>(eid->data);
......
...@@ -60,6 +60,8 @@ def test_pickling_frame(): ...@@ -60,6 +60,8 @@ def test_pickling_frame():
def _assert_is_identical(g, g2): def _assert_is_identical(g, g2):
assert g.is_multigraph == g2.is_multigraph
assert g.is_readonly == g2.is_readonly
assert g.number_of_nodes() == g2.number_of_nodes() assert g.number_of_nodes() == g2.number_of_nodes()
src, dst = g.all_edges() src, dst = g.all_edges()
src2, dst2 = g2.all_edges() src2, dst2 = g2.all_edges()
...@@ -140,6 +142,21 @@ def test_pickling_graph(): ...@@ -140,6 +142,21 @@ def test_pickling_graph():
_assert_is_identical(g, new_g) _assert_is_identical(g, new_g)
_assert_is_identical(g2, new_g2) _assert_is_identical(g2, new_g2)
# readonly graph
g = dgl.DGLGraph([(0, 1), (1, 2)], readonly=True)
new_g = _reconstruct_pickle(g)
_assert_is_identical(g, new_g)
# multigraph
g = dgl.DGLGraph([(0, 1), (0, 1), (1, 2)], multigraph=True)
new_g = _reconstruct_pickle(g)
_assert_is_identical(g, new_g)
# readonly multigraph
g = dgl.DGLGraph([(0, 1), (0, 1), (1, 2)], multigraph=True, readonly=True)
new_g = _reconstruct_pickle(g)
_assert_is_identical(g, new_g)
if __name__ == '__main__': if __name__ == '__main__':
test_pickling_index() test_pickling_index()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment