Unverified Commit fa0ee46a authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[hotfix] node id validity check (#1073)

* fix

* improve

* fix lint

* upd

* fix

* upd
parent bc4f4352
...@@ -21,7 +21,7 @@ __all__ = [ ...@@ -21,7 +21,7 @@ __all__ = [
'to_networkx', 'to_networkx',
] ]
def graph(data, ntype='_N', etype='_E', card=None, **kwargs): def graph(data, ntype='_N', etype='_E', card=None, validate=False, **kwargs):
"""Create a graph with one type of nodes and edges. """Create a graph with one type of nodes and edges.
In the sparse matrix perspective, :func:`dgl.graph` creates a graph In the sparse matrix perspective, :func:`dgl.graph` creates a graph
...@@ -45,6 +45,10 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs): ...@@ -45,6 +45,10 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
card : int, optional card : int, optional
Cardinality (number of nodes in the graph). If None, infer from input data, i.e. Cardinality (number of nodes in the graph). If None, infer from input data, i.e.
the largest node ID plus 1. (Default: None) the largest node ID plus 1. (Default: None)
validate : bool, optional
If True, check whether all node IDs are within the cardinality. The check
may take some time.
If False and card is not None, the user receives a warning. (Default: False)
kwargs : key-word arguments, optional kwargs : key-word arguments, optional
Other key word arguments. Only comes into effect when we are using a NetworkX Other key word arguments. Only comes into effect when we are using a NetworkX
graph. It can consist of: graph. It can consist of:
...@@ -101,6 +105,16 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs): ...@@ -101,6 +105,16 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
['follows'] ['follows']
>>> g.canonical_etypes >>> g.canonical_etypes
[('user', 'follows', 'user')] [('user', 'follows', 'user')]
Check if node ids are within cardinality
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]), card=2, validate=True)
...
dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2).
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]), card=3, validate=True)
>>> g
Graph(num_nodes=3, num_edges=3,
ndata_schemes={}
edata_schemes={})
""" """
if card is not None: if card is not None:
urange, vrange = card, card urange, vrange = card, card
...@@ -108,9 +122,9 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs): ...@@ -108,9 +122,9 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
urange, vrange = None, None urange, vrange = None, None
if isinstance(data, tuple): if isinstance(data, tuple):
u, v = data u, v = data
return create_from_edges(u, v, ntype, etype, ntype, urange, vrange) return create_from_edges(u, v, ntype, etype, ntype, urange, vrange, validate)
elif isinstance(data, list): elif isinstance(data, list):
return create_from_edge_list(data, ntype, etype, ntype, urange, vrange) return create_from_edge_list(data, ntype, etype, ntype, urange, vrange, validate)
elif isinstance(data, sp.sparse.spmatrix): elif isinstance(data, sp.sparse.spmatrix):
return create_from_scipy(data, ntype, etype, ntype) return create_from_scipy(data, ntype, etype, ntype)
elif isinstance(data, nx.Graph): elif isinstance(data, nx.Graph):
...@@ -118,7 +132,7 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs): ...@@ -118,7 +132,7 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
else: else:
raise DGLError('Unsupported graph data type:', type(data)) raise DGLError('Unsupported graph data type:', type(data))
def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs): def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=False, **kwargs):
"""Create a bipartite graph. """Create a bipartite graph.
The result graph is directed and edges must be from ``utype`` nodes The result graph is directed and edges must be from ``utype`` nodes
...@@ -147,6 +161,10 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs): ...@@ -147,6 +161,10 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
card : pair of int, optional card : pair of int, optional
Cardinality (number of nodes in the source and destination group). If None, Cardinality (number of nodes in the source and destination group). If None,
infer from input data, i.e. the largest node ID plus 1 for each type. (Default: None) infer from input data, i.e. the largest node ID plus 1 for each type. (Default: None)
validate : bool, optional
If True, check whether all node IDs are within the cardinality. The check
may take some time.
If False and card is not None, the user receives a warning. (Default: False)
kwargs : key-word arguments, optional kwargs : key-word arguments, optional
Other key word arguments. Only comes into effect when we are using a NetworkX Other key word arguments. Only comes into effect when we are using a NetworkX
graph. It can consist of: graph. It can consist of:
...@@ -215,6 +233,16 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs): ...@@ -215,6 +233,16 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
4 4
>>> g.edges() >>> g.edges()
(tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]), tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3])) (tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]), tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]))
Check if node ids are within cardinality
>>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]), card=(2, 4), validate=True)
...
dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2).
>>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]), card=(3, 4), validate=True)
>>> g
Graph(num_nodes={'_U': 3, '_V': 4},
num_edges={('_U', '_E', '_V'): 3},
metagraph=[('_U', '_V')])
""" """
if utype == vtype: if utype == vtype:
raise DGLError('utype should not be equal to vtype. Use ``dgl.graph`` instead.') raise DGLError('utype should not be equal to vtype. Use ``dgl.graph`` instead.')
...@@ -224,9 +252,9 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs): ...@@ -224,9 +252,9 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
urange, vrange = None, None urange, vrange = None, None
if isinstance(data, tuple): if isinstance(data, tuple):
u, v = data u, v = data
return create_from_edges(u, v, utype, etype, vtype, urange, vrange) return create_from_edges(u, v, utype, etype, vtype, urange, vrange, validate)
elif isinstance(data, list): elif isinstance(data, list):
return create_from_edge_list(data, utype, etype, vtype, urange, vrange) return create_from_edge_list(data, utype, etype, vtype, urange, vrange, validate)
elif isinstance(data, sp.sparse.spmatrix): elif isinstance(data, sp.sparse.spmatrix):
return create_from_scipy(data, utype, etype, vtype) return create_from_scipy(data, utype, etype, vtype)
elif isinstance(data, nx.Graph): elif isinstance(data, nx.Graph):
...@@ -667,7 +695,7 @@ def to_homo(G): ...@@ -667,7 +695,7 @@ def to_homo(G):
# Internal APIs # Internal APIs
############################################################ ############################################################
def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None): def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, validate=False):
"""Internal function to create a graph from incident nodes with types. """Internal function to create a graph from incident nodes with types.
utype could be equal to vtype utype could be equal to vtype
...@@ -690,6 +718,8 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None): ...@@ -690,6 +718,8 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
vrange : int, optional vrange : int, optional
The destination node ID range. If None, the value is the The destination node ID range. If None, the value is the
maximum of the destination node IDs in the edge list plus 1. (Default: None) maximum of the destination node IDs in the edge list plus 1. (Default: None)
validate : bool, optional
If True, checks if node IDs are within range.
Returns Returns
------- -------
...@@ -697,6 +727,13 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None): ...@@ -697,6 +727,13 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
""" """
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
if validate:
    # Reject edges whose endpoints fall outside the declared node ID ranges.
    # A side is only checked when its range was given explicitly; a range of
    # None is inferred from the data by the statements that follow.
    if urange is not None:
        # Largest source node ID, computed once and reused in the message.
        max_u = int(F.asnumpy(F.max(u.tousertensor(), dim=0)))
        if urange <= max_u:
            # Offending ID goes first, cardinality second (the original
            # passed these two arguments in the opposite order).
            raise DGLError('Invalid node id {} (should be less than cardinality {}).'.format(
                max_u, urange))
    if vrange is not None:
        # Largest destination node ID, computed once and reused in the message.
        max_v = int(F.asnumpy(F.max(v.tousertensor(), dim=0)))
        if vrange <= max_v:
            raise DGLError('Invalid node id {} (should be less than cardinality {}).'.format(
                max_v, vrange))
urange = urange or (int(F.asnumpy(F.max(u.tousertensor(), dim=0))) + 1) urange = urange or (int(F.asnumpy(F.max(u.tousertensor(), dim=0))) + 1)
vrange = vrange or (int(F.asnumpy(F.max(v.tousertensor(), dim=0))) + 1) vrange = vrange or (int(F.asnumpy(F.max(v.tousertensor(), dim=0))) + 1)
if utype == vtype: if utype == vtype:
...@@ -710,7 +747,7 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None): ...@@ -710,7 +747,7 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
else: else:
return DGLHeteroGraph(hgidx, [utype, vtype], [etype]) return DGLHeteroGraph(hgidx, [utype, vtype], [etype])
def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None): def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None, validate=False):
"""Internal function to create a heterograph from a list of edge tuples with types. """Internal function to create a heterograph from a list of edge tuples with types.
utype could be equal to vtype utype could be equal to vtype
...@@ -731,6 +768,9 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None): ...@@ -731,6 +768,9 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
vrange : int, optional vrange : int, optional
The destination node ID range. If None, the value is the The destination node ID range. If None, the value is the
maximum of the destination node IDs in the edge list plus 1. (Default: None) maximum of the destination node IDs in the edge list plus 1. (Default: None)
validate : bool, optional
If True, checks if node IDs are within range.
Returns Returns
------- -------
...@@ -742,7 +782,7 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None): ...@@ -742,7 +782,7 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
u, v = zip(*elist) u, v = zip(*elist)
u = list(u) u = list(u)
v = list(v) v = list(v)
return create_from_edges(u, v, utype, etype, vtype, urange, vrange) return create_from_edges(u, v, utype, etype, vtype, urange, vrange, validate)
def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False): def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False):
"""Internal function to create a heterograph from a scipy sparse matrix with types. """Internal function to create a heterograph from a scipy sparse matrix with types.
...@@ -762,6 +802,9 @@ def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False): ...@@ -762,6 +802,9 @@ def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False):
If True, the entries in the sparse matrix are treated as edge IDs. If True, the entries in the sparse matrix are treated as edge IDs.
Otherwise, the entries are ignored and edges will be added in Otherwise, the entries are ignored and edges will be added in
(source, destination) order. (source, destination) order.
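validate : bool, optional
If True, checks if node IDs are within range.
NOTE(review): the visible signature of this function has no ``validate``
parameter -- confirm that this entry belongs in this docstring.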
Returns Returns
------- -------
......
...@@ -6,6 +6,8 @@ import scipy.sparse as ssp ...@@ -6,6 +6,8 @@ import scipy.sparse as ssp
import itertools import itertools
import backend as F import backend as F
import networkx as nx import networkx as nx
from dgl import DGLError
def create_test_heterograph(): def create_test_heterograph():
# test heterograph from the docstring, plus a user -- wishes -- game relation # test heterograph from the docstring, plus a user -- wishes -- game relation
...@@ -93,6 +95,36 @@ def test_create(): ...@@ -93,6 +95,36 @@ def test_create():
assert g.number_of_nodes('l1') == 3 assert g.number_of_nodes('l1') == 3
assert g.number_of_nodes('l2') == 4 assert g.number_of_nodes('l2') == 4
# test if validate flag works
# homo graph
fail = False
try:
g = dgl.graph(
([0, 0, 0, 1, 1, 2], [0, 1, 2, 0, 1, 2]),
card=2,
validate=True
)
except DGLError:
fail = True
finally:
assert fail, "should catch a DGLError because node ID is out of bound."
# bipartite graph
def _test_validate_bipartite(card):
    """Check that ``dgl.bipartite`` raises ``DGLError`` for ``card`` when validating.

    The edge list uses source IDs up to 2 and destination IDs up to 3, so the
    caller is expected to pass a ``card`` that is too small on at least one side.

    Parameters
    ----------
    card : pair of int
        Cardinality passed through to ``dgl.bipartite``.

    Raises
    ------
    AssertionError
        If graph construction succeeds without raising ``DGLError``.
    """
    try:
        dgl.bipartite(
            ([0, 0, 1, 1, 2], [1, 1, 2, 2, 3]),
            card=card,
            validate=True
        )
    except DGLError:
        # Expected outcome: the out-of-range node ID was detected.
        return
    # Reached only when no exception was raised. Any exception other than
    # DGLError propagates unchanged instead of being masked by an assertion
    # placed in a ``finally`` clause.
    raise AssertionError("should catch a DGLError because node ID is out of bound.")
_test_validate_bipartite((3, 3))
_test_validate_bipartite((2, 4))
def test_query(): def test_query():
g = create_test_heterograph() g = create_test_heterograph()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment