Commit 063ed69c authored by Da Zheng's avatar Da Zheng Committed by Minjie Wang
Browse files

[BUGFIX] fix a bug in creating immutable graph index. (#251)

* fix a bug in creating immutable graph index.

* fix for new changes in the backend API.

* fix for creating immutable graph index from coo matrix.

* retrigger
parent 2194b7df
......@@ -345,8 +345,9 @@ class ImmutableGraphIndex(object):
src = mx.nd.array(out_coo.row, dtype=np.int64)
dst = mx.nd.array(out_coo.col, dtype=np.int64)
# TODO we can't generate a csr_matrix with np.int64 directly.
self.__init__(mx.nd.sparse.csr_matrix((edge_ids, (dst, src)), shape=out_coo.shape).astype(np.int64),
mx.nd.sparse.csr_matrix((edge_ids, (src, dst)), shape=out_coo.shape).astype(np.int64))
size = max(out_coo.shape)
self.__init__(mx.nd.sparse.csr_matrix((edge_ids, (dst, src)), shape=(size, size)).astype(np.int64),
mx.nd.sparse.csr_matrix((edge_ids, (src, dst)), shape=(size, size)).astype(np.int64))
def create_immutable_graph_index(in_csr=None, out_csr=None):
""" Create an empty backend-specific immutable graph index.
......
......@@ -582,12 +582,42 @@ class ImmutableGraphIndex(object):
nx_graph : networkx.DiGraph
The nx graph
"""
if not isinstance(nx_graph, nx.DiGraph):
nx_graph = nx.DiGraph(nx_graph)
if not isinstance(nx_graph, nx.Graph):
nx_graph = (nx.MultiDiGraph(nx_graph) if self.is_multigraph()
else nx.DiGraph(nx_graph))
else:
nx_graph = nx_graph.to_directed()
assert nx_graph.number_of_edges() > 0, "can't create an empty immutable graph"
# nx_graph.edges(data=True) returns src, dst, attr_dict
has_edge_id = 'id' in next(iter(nx_graph.edges(data=True)))[-1]
if has_edge_id:
num_edges = nx_graph.number_of_edges()
src = np.zeros((num_edges,), dtype=np.int64)
dst = np.zeros((num_edges,), dtype=np.int64)
for u, v, attr in nx_graph.edges(data=True):
eid = attr['id']
src[eid] = u
dst[eid] = v
else:
src = []
dst = []
for e in nx_graph.edges:
src.append(e[0])
dst.append(e[1])
eid = np.arange(0, len(src), dtype=np.int64)
num_nodes = nx_graph.number_of_nodes()
# We store edge Ids as an edge attribute.
nodelist = list(range(nx_graph.number_of_nodes()))
out_mat = nx.convert_matrix.to_scipy_sparse_matrix(nx_graph, nodelist=nodelist, format='coo')
self._sparse.from_coo_matrix(out_mat)
eid = F.tensor(eid, dtype=np.int32)
src = F.tensor(src, dtype=np.int64)
dst = F.tensor(dst, dtype=np.int64)
out_csr, _ = F.sparse_matrix(eid, ('coo', (src, dst)), (num_nodes, num_nodes))
in_csr, _ = F.sparse_matrix(eid, ('coo', (dst, src)), (num_nodes, num_nodes))
out_csr = out_csr.astype(np.int64)
in_csr = in_csr.astype(np.int64)
self._sparse = F.create_immutable_graph_index(in_csr, out_csr)
def from_scipy_sparse_matrix(self, adj):
"""Convert from scipy sparse matrix.
......
......@@ -3,6 +3,7 @@ os.environ['DGLBACKEND'] = 'mxnet'
import mxnet as mx
import numpy as np
import scipy as sp
import dgl
from dgl.graph import GraphIndex, create_graph_index
from dgl.graph_index import map_to_subgraph_nid
from dgl import utils
......@@ -99,8 +100,23 @@ def test_node_subgraph():
for i in range(4):
check_graph_equal(subgs[i], subigs[i])
def test_create_graph():
elist = [(1, 2), (0, 1), (0, 2)]
ig = dgl.DGLGraph(elist, readonly=True)
g = dgl.DGLGraph(elist, readonly=False)
for edge in elist:
assert g.edge_id(edge[0], edge[1]) == ig.edge_id(edge[0], edge[1])
data = [1, 2, 3]
rows = [1, 0, 0]
cols = [2, 1, 2]
mat = sp.sparse.coo_matrix((data, (rows, cols)))
ig = dgl.DGLGraph(mat, readonly=True)
for edge in elist:
assert g.edge_id(edge[0], edge[1]) == ig.edge_id(edge[0], edge[1])
if __name__ == '__main__':
test_basics()
test_graph_gen()
test_node_subgraph()
test_create_graph()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment