Commit 31a7d509 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by Minjie Wang
Browse files

[Bug Fix] Add boundary check for heterograph build with card as input (#1192)

* Add boundary check for heterograph build with card as input

* Fix when u or v is empty

* fix test_kernel.py error print

* Revert "fix test_kernel.py error print"

This reverts commit a71c20292549c0ac62a5326c30669ca4bde8febc.

* Turn op validation check on graph and bipartite by default

* upd

* udp

* upd

* update test
parent c49582c9
...@@ -21,7 +21,7 @@ __all__ = [ ...@@ -21,7 +21,7 @@ __all__ = [
'to_networkx', 'to_networkx',
] ]
def graph(data, ntype='_N', etype='_E', card=None, validate=False, **kwargs): def graph(data, ntype='_N', etype='_E', card=None, validate=True, **kwargs):
"""Create a graph with one type of nodes and edges. """Create a graph with one type of nodes and edges.
In the sparse matrix perspective, :func:`dgl.graph` creates a graph In the sparse matrix perspective, :func:`dgl.graph` creates a graph
...@@ -47,8 +47,8 @@ def graph(data, ntype='_N', etype='_E', card=None, validate=False, **kwargs): ...@@ -47,8 +47,8 @@ def graph(data, ntype='_N', etype='_E', card=None, validate=False, **kwargs):
the largest node ID plus 1. (Default: None) the largest node ID plus 1. (Default: None)
validate : bool, optional validate : bool, optional
If True, check if node ids are within cardinality, the check process may take If True, check if node ids are within cardinality, the check process may take
some time. some time. (Default: True)
If False and card is not None, user would receive a warning. (Default: False) If False and card is not None, user would receive a warning.
kwargs : key-word arguments, optional kwargs : key-word arguments, optional
Other key word arguments. Only comes into effect when we are using a NetworkX Other key word arguments. Only comes into effect when we are using a NetworkX
graph. It can consist of: graph. It can consist of:
...@@ -132,7 +132,7 @@ def graph(data, ntype='_N', etype='_E', card=None, validate=False, **kwargs): ...@@ -132,7 +132,7 @@ def graph(data, ntype='_N', etype='_E', card=None, validate=False, **kwargs):
else: else:
raise DGLError('Unsupported graph data type:', type(data)) raise DGLError('Unsupported graph data type:', type(data))
def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=False, **kwargs): def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True, **kwargs):
"""Create a bipartite graph. """Create a bipartite graph.
The result graph is directed and edges must be from ``utype`` nodes The result graph is directed and edges must be from ``utype`` nodes
...@@ -163,8 +163,8 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=Fals ...@@ -163,8 +163,8 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=Fals
infer from input data, i.e. the largest node ID plus 1 for each type. (Default: None) infer from input data, i.e. the largest node ID plus 1 for each type. (Default: None)
validate : bool, optional validate : bool, optional
If True, check if node ids are within cardinality, the check process may take If True, check if node ids are within cardinality, the check process may take
some time. some time. (Default: True)
If False and card is not None, user would receive a warning. (Default: False) If False and card is not None, user would receive a warning.
kwargs : key-word arguments, optional kwargs : key-word arguments, optional
Other key word arguments. Only comes into effect when we are using a NetworkX Other key word arguments. Only comes into effect when we are using a NetworkX
graph. It can consist of: graph. It can consist of:
...@@ -433,11 +433,13 @@ def heterograph(data_dict, num_nodes_dict=None): ...@@ -433,11 +433,13 @@ def heterograph(data_dict, num_nodes_dict=None):
if isinstance(data, DGLHeteroGraph): if isinstance(data, DGLHeteroGraph):
rel_graphs.append(data) rel_graphs.append(data)
elif srctype == dsttype: elif srctype == dsttype:
rel_graphs.append(graph(data, srctype, etype, card=num_nodes_dict[srctype])) rel_graphs.append(graph(
data, srctype, etype,
card=num_nodes_dict[srctype], validate=False))
else: else:
rel_graphs.append(bipartite( rel_graphs.append(bipartite(
data, srctype, etype, dsttype, data, srctype, etype, dsttype,
card=(num_nodes_dict[srctype], num_nodes_dict[dsttype]))) card=(num_nodes_dict[srctype], num_nodes_dict[dsttype]), validate=False))
return hetero_from_relations(rel_graphs) return hetero_from_relations(rel_graphs)
...@@ -590,11 +592,11 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph ...@@ -590,11 +592,11 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph
if stid == dtid: if stid == dtid:
rel_graph = graph( rel_graph = graph(
(src_of_etype, dst_of_etype), ntypes[stid], etypes[etid], (src_of_etype, dst_of_etype), ntypes[stid], etypes[etid],
card=ntype_count[stid]) card=ntype_count[stid], validate=False)
else: else:
rel_graph = bipartite( rel_graph = bipartite(
(src_of_etype, dst_of_etype), ntypes[stid], etypes[etid], ntypes[dtid], (src_of_etype, dst_of_etype), ntypes[stid], etypes[etid], ntypes[dtid],
card=(ntype_count[stid], ntype_count[dtid])) card=(ntype_count[stid], ntype_count[dtid]), validate=False)
rel_graphs.append(rel_graph) rel_graphs.append(rel_graph)
hg = hetero_from_relations(rel_graphs) hg = hetero_from_relations(rel_graphs)
...@@ -681,7 +683,7 @@ def to_homo(G): ...@@ -681,7 +683,7 @@ def to_homo(G):
etype_ids.append(F.full_1d(num_edges, etype_id, F.int64, F.cpu())) etype_ids.append(F.full_1d(num_edges, etype_id, F.int64, F.cpu()))
eids.append(F.arange(0, num_edges)) eids.append(F.arange(0, num_edges))
retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), card=total_num_nodes) retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), card=total_num_nodes, validate=False)
retg.ndata[NTYPE] = F.cat(ntype_ids, 0) retg.ndata[NTYPE] = F.cat(ntype_ids, 0)
retg.ndata[NID] = F.cat(nids, 0) retg.ndata[NID] = F.cat(nids, 0)
retg.edata[ETYPE] = F.cat(etype_ids, 0) retg.edata[ETYPE] = F.cat(etype_ids, 0)
...@@ -701,7 +703,7 @@ def to_homo(G): ...@@ -701,7 +703,7 @@ def to_homo(G):
# Internal APIs # Internal APIs
############################################################ ############################################################
def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, validate=False): def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, validate=True):
"""Internal function to create a graph from incident nodes with types. """Internal function to create a graph from incident nodes with types.
utype could be equal to vtype utype could be equal to vtype
...@@ -735,10 +737,12 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, valid ...@@ -735,10 +737,12 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, valid
v = utils.toindex(v) v = utils.toindex(v)
if validate: if validate:
if urange is not None and urange <= int(F.asnumpy(F.max(u.tousertensor(), dim=0))): if urange is not None and len(u) > 0 and \
urange <= int(F.asnumpy(F.max(u.tousertensor(), dim=0))):
raise DGLError('Invalid node id {} (should be less than cardinality {}).'.format( raise DGLError('Invalid node id {} (should be less than cardinality {}).'.format(
urange, int(F.asnumpy(F.max(u.tousertensor(), dim=0))))) urange, int(F.asnumpy(F.max(u.tousertensor(), dim=0)))))
if vrange is not None and vrange <= int(F.asnumpy(F.max(v.tousertensor(), dim=0))): if vrange is not None and len(v) > 0 and \
vrange <= int(F.asnumpy(F.max(v.tousertensor(), dim=0))):
raise DGLError('Invalid node id {} (should be less than cardinality {}).'.format( raise DGLError('Invalid node id {} (should be less than cardinality {}).'.format(
vrange, int(F.asnumpy(F.max(v.tousertensor(), dim=0))))) vrange, int(F.asnumpy(F.max(v.tousertensor(), dim=0)))))
urange = urange or ( urange = urange or (
...@@ -757,7 +761,7 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, valid ...@@ -757,7 +761,7 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, valid
else: else:
return DGLHeteroGraph(hgidx, [utype, vtype], [etype]) return DGLHeteroGraph(hgidx, [utype, vtype], [etype])
def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None, validate=False): def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None, validate=True):
"""Internal function to create a heterograph from a list of edge tuples with types. """Internal function to create a heterograph from a list of edge tuples with types.
utype could be equal to vtype utype could be equal to vtype
...@@ -895,7 +899,7 @@ def create_from_networkx(nx_graph, ...@@ -895,7 +899,7 @@ def create_from_networkx(nx_graph,
src = utils.toindex(src) src = utils.toindex(src)
dst = utils.toindex(dst) dst = utils.toindex(dst)
num_nodes = nx_graph.number_of_nodes() num_nodes = nx_graph.number_of_nodes()
g = create_from_edges(src, dst, ntype, etype, ntype, num_nodes, num_nodes) g = create_from_edges(src, dst, ntype, etype, ntype, num_nodes, num_nodes, validate=False)
# handle features # handle features
# copy attributes # copy attributes
...@@ -1007,7 +1011,9 @@ def create_from_networkx_bipartite(nx_graph, ...@@ -1007,7 +1011,9 @@ def create_from_networkx_bipartite(nx_graph,
dst.append(bottom_map[e[1]]) dst.append(bottom_map[e[1]])
src = utils.toindex(src) src = utils.toindex(src)
dst = utils.toindex(dst) dst = utils.toindex(dst)
g = create_from_edges(src, dst, utype, etype, vtype, len(top_nodes), len(bottom_nodes)) g = create_from_edges(
src, dst, utype, etype, vtype,
len(top_nodes), len(bottom_nodes), validate=False)
# TODO attributes # TODO attributes
assert node_attrs is None, 'Retrieval of node attributes are not supported yet.' assert node_attrs is None, 'Retrieval of node attributes are not supported yet.'
......
...@@ -630,6 +630,27 @@ def test_to_device(): ...@@ -630,6 +630,27 @@ def test_to_device():
hg = hg.to(F.cuda()) hg = hg.to(F.cuda())
assert hg is not None assert hg is not None
def test_convert_bound():
def _test_bipartite_bound(data, card):
try:
dgl.bipartite(data, card=card)
except dgl.DGLError:
return
assert False, 'bipartite bound test with wrong uid failed'
def _test_graph_bound(data, card):
try:
dgl.graph(data, card=card)
except dgl.DGLError:
return
assert False, 'graph bound test with wrong uid failed'
_test_bipartite_bound(([1,2],[1,2]),(2,3))
_test_bipartite_bound(([0,1],[1,4]),(2,3))
_test_graph_bound(([1,3],[1,2]), 3)
_test_graph_bound(([0,1],[1,3]),3)
def test_convert(): def test_convert():
hg = create_test_heterograph() hg = create_test_heterograph()
hs = [] hs = []
...@@ -1265,6 +1286,7 @@ if __name__ == '__main__': ...@@ -1265,6 +1286,7 @@ if __name__ == '__main__':
test_view() test_view()
test_view1() test_view1()
test_flatten() test_flatten()
test_convert_bound()
test_convert() test_convert()
test_to_device() test_to_device()
test_transform() test_transform()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment