Unverified Commit fa0ee46a authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[hotfix] node id validity check (#1073)

* fix

* improve

* fix lint

* upd

* fix

* upd
parent bc4f4352
......@@ -21,7 +21,7 @@ __all__ = [
'to_networkx',
]
def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
def graph(data, ntype='_N', etype='_E', card=None, validate=False, **kwargs):
"""Create a graph with one type of nodes and edges.
In the sparse matrix perspective, :func:`dgl.graph` creates a graph
......@@ -45,6 +45,10 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
card : int, optional
Cardinality (number of nodes in the graph). If None, infer from input data, i.e.
the largest node ID plus 1. (Default: None)
validate : bool, optional
If True, check if node ids are within cardinality, the check process may take
some time.
If False and card is not None, user would receive a warning. (Default: False)
kwargs : key-word arguments, optional
Other key word arguments. Only comes into effect when we are using a NetworkX
graph. It can consist of:
......@@ -101,6 +105,16 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
['follows']
>>> g.canonical_etypes
[('user', 'follows', 'user')]
Check if node ids are within cardinality
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]), card=2, validate=True)
...
dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2).
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]), card=3, validate=True)
Graph(num_nodes=3, num_edges=3,
ndata_schemes={}
edata_schemes={})
"""
if card is not None:
urange, vrange = card, card
......@@ -108,9 +122,9 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
urange, vrange = None, None
if isinstance(data, tuple):
u, v = data
return create_from_edges(u, v, ntype, etype, ntype, urange, vrange)
return create_from_edges(u, v, ntype, etype, ntype, urange, vrange, validate)
elif isinstance(data, list):
return create_from_edge_list(data, ntype, etype, ntype, urange, vrange)
return create_from_edge_list(data, ntype, etype, ntype, urange, vrange, validate)
elif isinstance(data, sp.sparse.spmatrix):
return create_from_scipy(data, ntype, etype, ntype)
elif isinstance(data, nx.Graph):
......@@ -118,7 +132,7 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
else:
raise DGLError('Unsupported graph data type:', type(data))
def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=False, **kwargs):
"""Create a bipartite graph.
The result graph is directed and edges must be from ``utype`` nodes
......@@ -147,6 +161,10 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
card : pair of int, optional
Cardinality (number of nodes in the source and destination group). If None,
infer from input data, i.e. the largest node ID plus 1 for each type. (Default: None)
validate : bool, optional
If True, check if node ids are within cardinality, the check process may take
some time.
If False and card is not None, user would receive a warning. (Default: False)
kwargs : key-word arguments, optional
Other key word arguments. Only comes into effect when we are using a NetworkX
graph. It can consist of:
......@@ -215,6 +233,16 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
4
>>> g.edges()
(tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]), tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]))
Check if node ids are within cardinality
>>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]), card=(2, 4), validate=True)
...
dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2).
>>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]), card=(3, 4), validate=True)
>>> g
Graph(num_nodes={'_U': 3, '_V': 4},
num_edges={('_U', '_E', '_V'): 3},
metagraph=[('_U', '_V')])
"""
if utype == vtype:
raise DGLError('utype should not be equal to vtype. Use ``dgl.graph`` instead.')
......@@ -224,9 +252,9 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
urange, vrange = None, None
if isinstance(data, tuple):
u, v = data
return create_from_edges(u, v, utype, etype, vtype, urange, vrange)
return create_from_edges(u, v, utype, etype, vtype, urange, vrange, validate)
elif isinstance(data, list):
return create_from_edge_list(data, utype, etype, vtype, urange, vrange)
return create_from_edge_list(data, utype, etype, vtype, urange, vrange, validate)
elif isinstance(data, sp.sparse.spmatrix):
return create_from_scipy(data, utype, etype, vtype)
elif isinstance(data, nx.Graph):
......@@ -667,7 +695,7 @@ def to_homo(G):
# Internal APIs
############################################################
def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, validate=False):
"""Internal function to create a graph from incident nodes with types.
utype could be equal to vtype
......@@ -690,6 +718,8 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
vrange : int, optional
The destination node ID range. If None, the value is the
maximum of the destination node IDs in the edge list plus 1. (Default: None)
validate : bool, optional
If True, checks if node IDs are within range.
Returns
-------
......@@ -697,6 +727,13 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
"""
u = utils.toindex(u)
v = utils.toindex(v)
if validate:
if urange is not None and urange <= int(F.asnumpy(F.max(u.tousertensor(), dim=0))):
raise DGLError('Invalid node id {} (should be less than cardinality {}).'.format(
urange, int(F.asnumpy(F.max(u.tousertensor(), dim=0)))))
if vrange is not None and vrange <= int(F.asnumpy(F.max(v.tousertensor(), dim=0))):
raise DGLError('Invalid node id {} (should be less than cardinality {}).'.format(
vrange, int(F.asnumpy(F.max(v.tousertensor(), dim=0)))))
urange = urange or (int(F.asnumpy(F.max(u.tousertensor(), dim=0))) + 1)
vrange = vrange or (int(F.asnumpy(F.max(v.tousertensor(), dim=0))) + 1)
if utype == vtype:
......@@ -710,7 +747,7 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
else:
return DGLHeteroGraph(hgidx, [utype, vtype], [etype])
def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None, validate=False):
"""Internal function to create a heterograph from a list of edge tuples with types.
utype could be equal to vtype
......@@ -731,6 +768,9 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
vrange : int, optional
The destination node ID range. If None, the value is the
maximum of the destination node IDs in the edge list plus 1. (Default: None)
validate : bool, optional
If True, checks if node IDs are within range.
Returns
-------
......@@ -742,7 +782,7 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
u, v = zip(*elist)
u = list(u)
v = list(v)
return create_from_edges(u, v, utype, etype, vtype, urange, vrange)
return create_from_edges(u, v, utype, etype, vtype, urange, vrange, validate)
def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False):
"""Internal function to create a heterograph from a scipy sparse matrix with types.
......@@ -762,6 +802,9 @@ def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False):
If True, the entries in the sparse matrix are treated as edge IDs.
Otherwise, the entries are ignored and edges will be added in
(source, destination) order.
validate : bool, optional
If True, checks if node IDs are within range.
Returns
-------
......
......@@ -6,6 +6,8 @@ import scipy.sparse as ssp
import itertools
import backend as F
import networkx as nx
from dgl import DGLError
def create_test_heterograph():
# test heterograph from the docstring, plus a user -- wishes -- game relation
......@@ -93,6 +95,36 @@ def test_create():
assert g.number_of_nodes('l1') == 3
assert g.number_of_nodes('l2') == 4
# test if validate flag works
# homo graph
fail = False
try:
g = dgl.graph(
([0, 0, 0, 1, 1, 2], [0, 1, 2, 0, 1, 2]),
card=2,
validate=True
)
except DGLError:
fail = True
finally:
assert fail, "should catch a DGLError because node ID is out of bound."
# bipartite graph
def _test_validate_bipartite(card):
fail = False
try:
g = dgl.bipartite(
([0, 0, 1, 1, 2], [1, 1, 2, 2, 3]),
card=card,
validate=True
)
except DGLError:
fail = True
finally:
assert fail, "should catch a DGLError because node ID is out of bound."
_test_validate_bipartite((3, 3))
_test_validate_bipartite((2, 4))
def test_query():
g = create_test_heterograph()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment