Commit 31a7d509 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by Minjie Wang
Browse files

[Bug Fix] Add boundary check for heterograph build with card as input (#1192)

* Add boundary check for heterograph build with card as input

* Fix when u or v is empty

* fix test_kernel.py error print

* Revert "fix test_kernel.py error print"

This reverts commit a71c20292549c0ac62a5326c30669ca4bde8febc.

* Turn op validation check on graph and bipartite by default

* upd

* udp

* upd

* update test
parent c49582c9
......@@ -21,7 +21,7 @@ __all__ = [
'to_networkx',
]
def graph(data, ntype='_N', etype='_E', card=None, validate=False, **kwargs):
def graph(data, ntype='_N', etype='_E', card=None, validate=True, **kwargs):
"""Create a graph with one type of nodes and edges.
In the sparse matrix perspective, :func:`dgl.graph` creates a graph
......@@ -47,8 +47,8 @@ def graph(data, ntype='_N', etype='_E', card=None, validate=False, **kwargs):
the largest node ID plus 1. (Default: None)
validate : bool, optional
If True, check if node ids are within cardinality, the check process may take
some time.
If False and card is not None, user would receive a warning. (Default: False)
some time. (Default: True)
If False and card is not None, user would receive a warning.
kwargs : key-word arguments, optional
Other key word arguments. Only comes into effect when we are using a NetworkX
graph. It can consist of:
......@@ -132,7 +132,7 @@ def graph(data, ntype='_N', etype='_E', card=None, validate=False, **kwargs):
else:
raise DGLError('Unsupported graph data type:', type(data))
def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=False, **kwargs):
def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=True, **kwargs):
"""Create a bipartite graph.
The result graph is directed and edges must be from ``utype`` nodes
......@@ -163,8 +163,8 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=Fals
infer from input data, i.e. the largest node ID plus 1 for each type. (Default: None)
validate : bool, optional
If True, check if node ids are within cardinality, the check process may take
some time.
If False and card is not None, user would receive a warning. (Default: False)
some time. (Default: True)
If False and card is not None, user would receive a warning.
kwargs : key-word arguments, optional
Other key word arguments. Only comes into effect when we are using a NetworkX
graph. It can consist of:
......@@ -433,11 +433,13 @@ def heterograph(data_dict, num_nodes_dict=None):
if isinstance(data, DGLHeteroGraph):
rel_graphs.append(data)
elif srctype == dsttype:
rel_graphs.append(graph(data, srctype, etype, card=num_nodes_dict[srctype]))
rel_graphs.append(graph(
data, srctype, etype,
card=num_nodes_dict[srctype], validate=False))
else:
rel_graphs.append(bipartite(
data, srctype, etype, dsttype,
card=(num_nodes_dict[srctype], num_nodes_dict[dsttype])))
card=(num_nodes_dict[srctype], num_nodes_dict[dsttype]), validate=False))
return hetero_from_relations(rel_graphs)
......@@ -590,11 +592,11 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph
if stid == dtid:
rel_graph = graph(
(src_of_etype, dst_of_etype), ntypes[stid], etypes[etid],
card=ntype_count[stid])
card=ntype_count[stid], validate=False)
else:
rel_graph = bipartite(
(src_of_etype, dst_of_etype), ntypes[stid], etypes[etid], ntypes[dtid],
card=(ntype_count[stid], ntype_count[dtid]))
card=(ntype_count[stid], ntype_count[dtid]), validate=False)
rel_graphs.append(rel_graph)
hg = hetero_from_relations(rel_graphs)
......@@ -681,7 +683,7 @@ def to_homo(G):
etype_ids.append(F.full_1d(num_edges, etype_id, F.int64, F.cpu()))
eids.append(F.arange(0, num_edges))
retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), card=total_num_nodes)
retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), card=total_num_nodes, validate=False)
retg.ndata[NTYPE] = F.cat(ntype_ids, 0)
retg.ndata[NID] = F.cat(nids, 0)
retg.edata[ETYPE] = F.cat(etype_ids, 0)
......@@ -701,7 +703,7 @@ def to_homo(G):
# Internal APIs
############################################################
def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, validate=False):
def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, validate=True):
"""Internal function to create a graph from incident nodes with types.
utype could be equal to vtype
......@@ -735,10 +737,12 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, valid
v = utils.toindex(v)
if validate:
if urange is not None and urange <= int(F.asnumpy(F.max(u.tousertensor(), dim=0))):
if urange is not None and len(u) > 0 and \
urange <= int(F.asnumpy(F.max(u.tousertensor(), dim=0))):
raise DGLError('Invalid node id {} (should be less than cardinality {}).'.format(
urange, int(F.asnumpy(F.max(u.tousertensor(), dim=0)))))
if vrange is not None and vrange <= int(F.asnumpy(F.max(v.tousertensor(), dim=0))):
if vrange is not None and len(v) > 0 and \
vrange <= int(F.asnumpy(F.max(v.tousertensor(), dim=0))):
raise DGLError('Invalid node id {} (should be less than cardinality {}).'.format(
vrange, int(F.asnumpy(F.max(v.tousertensor(), dim=0)))))
urange = urange or (
......@@ -757,7 +761,7 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, valid
else:
return DGLHeteroGraph(hgidx, [utype, vtype], [etype])
def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None, validate=False):
def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None, validate=True):
"""Internal function to create a heterograph from a list of edge tuples with types.
utype could be equal to vtype
......@@ -895,7 +899,7 @@ def create_from_networkx(nx_graph,
src = utils.toindex(src)
dst = utils.toindex(dst)
num_nodes = nx_graph.number_of_nodes()
g = create_from_edges(src, dst, ntype, etype, ntype, num_nodes, num_nodes)
g = create_from_edges(src, dst, ntype, etype, ntype, num_nodes, num_nodes, validate=False)
# handle features
# copy attributes
......@@ -1007,7 +1011,9 @@ def create_from_networkx_bipartite(nx_graph,
dst.append(bottom_map[e[1]])
src = utils.toindex(src)
dst = utils.toindex(dst)
g = create_from_edges(src, dst, utype, etype, vtype, len(top_nodes), len(bottom_nodes))
g = create_from_edges(
src, dst, utype, etype, vtype,
len(top_nodes), len(bottom_nodes), validate=False)
# TODO attributes
assert node_attrs is None, 'Retrieval of node attributes are not supported yet.'
......
......@@ -630,6 +630,27 @@ def test_to_device():
hg = hg.to(F.cuda())
assert hg is not None
def test_convert_bound():
def _test_bipartite_bound(data, card):
try:
dgl.bipartite(data, card=card)
except dgl.DGLError:
return
assert False, 'bipartite bound test with wrong uid failed'
def _test_graph_bound(data, card):
try:
dgl.graph(data, card=card)
except dgl.DGLError:
return
assert False, 'graph bound test with wrong uid failed'
_test_bipartite_bound(([1,2],[1,2]),(2,3))
_test_bipartite_bound(([0,1],[1,4]),(2,3))
_test_graph_bound(([1,3],[1,2]), 3)
_test_graph_bound(([0,1],[1,3]),3)
def test_convert():
hg = create_test_heterograph()
hs = []
......@@ -1265,6 +1286,7 @@ if __name__ == '__main__':
test_view()
test_view1()
test_flatten()
test_convert_bound()
test_convert()
test_to_device()
test_transform()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment