Commit 063ed69c authored by Da Zheng's avatar Da Zheng Committed by Minjie Wang
Browse files

[BUGFIX] fix a bug in creating immutable graph index. (#251)

* fix a bug in creating immutable graph index.

* fix for new changes in the backend API.

* fix for creating immutable graph index from coo matrix.

* retrigger
parent 2194b7df
...@@ -345,8 +345,9 @@ class ImmutableGraphIndex(object): ...@@ -345,8 +345,9 @@ class ImmutableGraphIndex(object):
src = mx.nd.array(out_coo.row, dtype=np.int64) src = mx.nd.array(out_coo.row, dtype=np.int64)
dst = mx.nd.array(out_coo.col, dtype=np.int64) dst = mx.nd.array(out_coo.col, dtype=np.int64)
# TODO we can't generate a csr_matrix with np.int64 directly. # TODO we can't generate a csr_matrix with np.int64 directly.
self.__init__(mx.nd.sparse.csr_matrix((edge_ids, (dst, src)), shape=out_coo.shape).astype(np.int64), size = max(out_coo.shape)
mx.nd.sparse.csr_matrix((edge_ids, (src, dst)), shape=out_coo.shape).astype(np.int64)) self.__init__(mx.nd.sparse.csr_matrix((edge_ids, (dst, src)), shape=(size, size)).astype(np.int64),
mx.nd.sparse.csr_matrix((edge_ids, (src, dst)), shape=(size, size)).astype(np.int64))
def create_immutable_graph_index(in_csr=None, out_csr=None): def create_immutable_graph_index(in_csr=None, out_csr=None):
""" Create an empty backend-specific immutable graph index. """ Create an empty backend-specific immutable graph index.
......
...@@ -582,12 +582,42 @@ class ImmutableGraphIndex(object): ...@@ -582,12 +582,42 @@ class ImmutableGraphIndex(object):
nx_graph : networkx.DiGraph nx_graph : networkx.DiGraph
The nx graph The nx graph
""" """
if not isinstance(nx_graph, nx.DiGraph): if not isinstance(nx_graph, nx.Graph):
nx_graph = nx.DiGraph(nx_graph) nx_graph = (nx.MultiDiGraph(nx_graph) if self.is_multigraph()
else nx.DiGraph(nx_graph))
else:
nx_graph = nx_graph.to_directed()
assert nx_graph.number_of_edges() > 0, "can't create an empty immutable graph"
# nx_graph.edges(data=True) returns src, dst, attr_dict
has_edge_id = 'id' in next(iter(nx_graph.edges(data=True)))[-1]
if has_edge_id:
num_edges = nx_graph.number_of_edges()
src = np.zeros((num_edges,), dtype=np.int64)
dst = np.zeros((num_edges,), dtype=np.int64)
for u, v, attr in nx_graph.edges(data=True):
eid = attr['id']
src[eid] = u
dst[eid] = v
else:
src = []
dst = []
for e in nx_graph.edges:
src.append(e[0])
dst.append(e[1])
eid = np.arange(0, len(src), dtype=np.int64)
num_nodes = nx_graph.number_of_nodes()
# We store edge Ids as an edge attribute. # We store edge Ids as an edge attribute.
nodelist = list(range(nx_graph.number_of_nodes())) eid = F.tensor(eid, dtype=np.int32)
out_mat = nx.convert_matrix.to_scipy_sparse_matrix(nx_graph, nodelist=nodelist, format='coo') src = F.tensor(src, dtype=np.int64)
self._sparse.from_coo_matrix(out_mat) dst = F.tensor(dst, dtype=np.int64)
out_csr, _ = F.sparse_matrix(eid, ('coo', (src, dst)), (num_nodes, num_nodes))
in_csr, _ = F.sparse_matrix(eid, ('coo', (dst, src)), (num_nodes, num_nodes))
out_csr = out_csr.astype(np.int64)
in_csr = in_csr.astype(np.int64)
self._sparse = F.create_immutable_graph_index(in_csr, out_csr)
def from_scipy_sparse_matrix(self, adj): def from_scipy_sparse_matrix(self, adj):
"""Convert from scipy sparse matrix. """Convert from scipy sparse matrix.
......
...@@ -3,6 +3,7 @@ os.environ['DGLBACKEND'] = 'mxnet' ...@@ -3,6 +3,7 @@ os.environ['DGLBACKEND'] = 'mxnet'
import mxnet as mx import mxnet as mx
import numpy as np import numpy as np
import scipy as sp import scipy as sp
import dgl
from dgl.graph import GraphIndex, create_graph_index from dgl.graph import GraphIndex, create_graph_index
from dgl.graph_index import map_to_subgraph_nid from dgl.graph_index import map_to_subgraph_nid
from dgl import utils from dgl import utils
...@@ -99,8 +100,23 @@ def test_node_subgraph(): ...@@ -99,8 +100,23 @@ def test_node_subgraph():
for i in range(4): for i in range(4):
check_graph_equal(subgs[i], subigs[i]) check_graph_equal(subgs[i], subigs[i])
def test_create_graph():
elist = [(1, 2), (0, 1), (0, 2)]
ig = dgl.DGLGraph(elist, readonly=True)
g = dgl.DGLGraph(elist, readonly=False)
for edge in elist:
assert g.edge_id(edge[0], edge[1]) == ig.edge_id(edge[0], edge[1])
data = [1, 2, 3]
rows = [1, 0, 0]
cols = [2, 1, 2]
mat = sp.sparse.coo_matrix((data, (rows, cols)))
ig = dgl.DGLGraph(mat, readonly=True)
for edge in elist:
assert g.edge_id(edge[0], edge[1]) == ig.edge_id(edge[0], edge[1])
if __name__ == '__main__': if __name__ == '__main__':
test_basics() test_basics()
test_graph_gen() test_graph_gen()
test_node_subgraph() test_node_subgraph()
test_create_graph()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment