"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "ed41db8525b8a7d48fe130fe610da98e8a53d3b0"
Unverified Commit 2194b7df authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Bugfix] edge order is not preserved when converting from edge list (#262)

parent a1d50f0f
......@@ -34,6 +34,7 @@ def pytorch_unit_test(dev) {
def mxnet_unit_test(dev) {
withEnv(["DGL_LIBRARY_PATH=${env.WORKSPACE}/build", "PYTHONPATH=${env.WORKSPACE}/python"]) {
sh "python3 -m nose -v --with-xunit tests/mxnet"
sh "python3 -m nose -v --with-xunit tests/graph_index"
}
}
......
......@@ -678,6 +678,25 @@ class GraphIndex(object):
dst = utils.toindex(adj_coo.col)
self.add_edges(src, dst)
def from_edge_list(self, elist):
"""Convert from an edge list.
Paramters
---------
elist : list
List of (u, v) edge tuple.
"""
self.clear()
src, dst = zip(*elist)
src = np.array(src)
dst = np.array(dst)
num_nodes = max(src.max(), dst.max()) + 1
min_nodes = min(src.min(), dst.min())
if min_nodes != 0:
raise DGLError('Invalid edge list. Nodes must start from 0.')
self.add_nodes(num_nodes)
self.add_edges(utils.toindex(src), utils.toindex(dst))
def line_graph(self, backtracking=True):
"""Return the line graph of this graph.
......@@ -868,7 +887,10 @@ def create_graph_index(graph_data=None, multigraph=False, readonly=False):
return graph_data
if readonly and graph_data is not None:
gi = create_immutable_graph_index(graph_data)
try:
gi = create_immutable_graph_index(graph_data)
except:
gi = None
# If we can't create an immutable graph index, we'll have to fall back.
if gi is not None:
return gi
......@@ -879,19 +901,27 @@ def create_graph_index(graph_data=None, multigraph=False, readonly=False):
if graph_data is None:
return gi
# edge list
if isinstance(graph_data, (list, tuple)):
try:
gi.from_edge_list(graph_data)
return gi
except:
raise DGLError('Graph data is not a valid edge list.')
# scipy format
if isinstance(graph_data, scipy.sparse.spmatrix):
try:
gi.from_scipy_sparse_matrix(graph_data)
return gi
except:
raise Exception('Graph data is not a valid scipy sparse matrix.')
raise DGLError('Graph data is not a valid scipy sparse matrix.')
# networkx - any format
try:
gi.from_networkx(graph_data)
except:
raise Exception('Error while creating graph from input of type "%s".'
raise DGLError('Error while creating graph from input of type "%s".'
% type(graph_data))
return gi
......
......@@ -8,7 +8,7 @@ import scipy.sparse as sp
from ._ffi.function import _init_api
from . import backend as F
from . import utils
from .base import ALL, is_all, dgl_warning
from .base import ALL, is_all, dgl_warning, DGLError
class ImmutableGraphIndex(object):
"""Graph index object on immutable graphs.
......@@ -33,7 +33,7 @@ class ImmutableGraphIndex(object):
num : int
Number of nodes to be added.
"""
raise Exception('Immutable graph doesn\'t support adding nodes')
raise DGLError('Immutable graph doesn\'t support adding nodes')
def add_edge(self, u, v):
"""Add one edge.
......@@ -45,7 +45,7 @@ class ImmutableGraphIndex(object):
v : int
The dst node.
"""
raise Exception('Immutable graph doesn\'t support adding an edge')
raise DGLError('Immutable graph doesn\'t support adding an edge')
def add_edges(self, u, v):
"""Add many edges.
......@@ -57,11 +57,11 @@ class ImmutableGraphIndex(object):
v : utils.Index
The dst nodes.
"""
raise Exception('Immutable graph doesn\'t support adding edges')
raise DGLError('Immutable graph doesn\'t support adding edges')
def clear(self):
"""Clear the graph."""
raise Exception('Immutable graph doesn\'t support clearing up')
raise DGLError('Immutable graph doesn\'t support clearing up')
def is_multigraph(self):
"""Return whether the graph is a multigraph
......@@ -592,6 +592,8 @@ class ImmutableGraphIndex(object):
def from_scipy_sparse_matrix(self, adj):
"""Convert from scipy sparse matrix.
NOTE: we assume the row is src nodes and the col is dst nodes.
Parameters
----------
adj : scipy sparse matrix
......@@ -601,6 +603,26 @@ class ImmutableGraphIndex(object):
out_mat = adj.tocoo()
self._sparse.from_coo_matrix(out_mat)
def from_edge_list(self, elist):
"""Convert from an edge list.
Paramters
---------
elist : list
List of (u, v) edge tuple.
"""
self.clear()
src, dst = zip(*elist)
src = np.array(src)
dst = np.array(dst)
num_nodes = max(src.max(), dst.max()) + 1
min_nodes = min(src.min(), dst.min())
if min_nodes != 0:
raise DGLError('Invalid edge list. Nodes must start from 0.')
data = np.ones((len(src),), dtype=np.int32)
spmat = sp.coo_matrix((data, (src, dst)), shape=(num_nodes, num_nodes))
self._sparse.from_coo_matrix(spmat)
def line_graph(self, backtracking=True):
"""Return the line graph of this graph.
......@@ -728,19 +750,27 @@ def create_immutable_graph_index(graph_data=None):
# Let's create an empty graph index first.
gi = ImmutableGraphIndex(F.create_immutable_graph_index())
# edge list
if isinstance(graph_data, (list, tuple)):
try:
gi.from_edge_list(graph_data)
return gi
except:
raise DGLError('Graph data is not a valid edge list.')
# scipy format
if isinstance(graph_data, sp.spmatrix):
try:
gi.from_scipy_sparse_matrix(graph_data)
return gi
except:
raise Exception('Graph data is not a valid scipy sparse matrix.')
raise DGLError('Graph data is not a valid scipy sparse matrix.')
# networkx - any format
try:
gi.from_networkx(graph_data)
except:
raise Exception('Error while creating graph from input of type "%s".'
raise DGLError('Error while creating graph from input of type "%s".'
% type(graph_data))
return gi
......
......@@ -139,8 +139,19 @@ def test_predsucc():
assert 2 in succ
assert 0 in succ
def test_create_from_elist():
elist = [(2, 1), (1, 0), (2, 0), (3, 0), (0, 2)]
g = create_graph_index(elist)
for i, (u, v) in enumerate(elist):
assert g.edge_id(u, v)[0] == i
# immutable graph
g = create_graph_index(elist, readonly=True)
for i, (u, v) in enumerate(elist):
print(u, v, g.edge_id(u, v)[0])
assert g.edge_id(u, v)[0] == i
if __name__ == '__main__':
test_edge_id()
test_nx()
test_predsucc()
test_create_from_elist()
......@@ -2,6 +2,7 @@ import time
import math
import numpy as np
import scipy.sparse as sp
import networkx as nx
import torch as th
import dgl
import utils as U
......@@ -26,6 +27,16 @@ def test_graph_creation():
g.ndata['h'] = 3 * th.ones((5, 2))
assert U.allclose(3 * th.ones((5, 2)), g.ndata['h'])
def test_create_from_elist():
elist = [(2, 1), (1, 0), (2, 0), (3, 0), (0, 2)]
g = dgl.DGLGraph(elist)
for i, (u, v) in enumerate(elist):
assert g.edge_id(u, v) == i
# immutable graph
g = dgl.DGLGraph(elist, readonly=True)
for i, (u, v) in enumerate(elist):
assert g.edge_id(u, v) == i
def test_adjmat_speed():
n = 1000
p = 10 * math.log(n) / n
......@@ -87,6 +98,7 @@ def test_incmat_speed():
if __name__ == '__main__':
test_graph_creation()
test_create_from_elist()
test_adjmat_speed()
test_incmat()
test_incmat_speed()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment