Unverified Commit fe207b45 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[HeteroGraph] Improve Error Messages for find_edges (#1626)



* Improve error message

* Fix

* Fix
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent 77006db2
......@@ -1514,6 +1514,18 @@ class DGLHeteroGraph(object):
(tensor([0, 1]), tensor([0, 2]))
"""
check_same_dtype(self._idtype_str, eid)
if F.is_tensor(eid):
max_eid = F.max(eid, dim=0)
else:
max_eid = np.max(eid, axis=0)
max_valid_eid = self.number_of_edges(etype) - 1
valid_ids = max_eid <= max_valid_eid
if etype is None:
assert valid_ids, \
'Expect edge ids to be in [0, ..., {:d}], got {}'.format(max_valid_eid, max_eid)
else:
assert valid_ids, 'Expect edge ids to be in [0, ..., {:d}]' \
' for type {}, got {}'.format(max_valid_eid, etype, max_eid)
eid = utils.toindex(eid, self._idtype_str)
src, dst, _ = self._graph.find_edges(self.get_etype_id(etype), eid)
return src.tousertensor(), dst.tousertensor()
......
......@@ -238,9 +238,10 @@ def test_query(index_dtype):
assert F.asnumpy(e).tolist() == list(range(n_edges))
# find_edges
u, v = g.find_edges(list(range(n_edges)), etype)
assert F.asnumpy(u).tolist() == srcs
assert F.asnumpy(v).tolist() == dsts
for edge_ids in [list(range(n_edges)), np.arange(n_edges), F.astype(F.arange(0, n_edges), g.idtype)]:
u, v = g.find_edges(edge_ids, etype)
assert F.asnumpy(u).tolist() == srcs
assert F.asnumpy(v).tolist() == dsts
# all_edges.
for order in ['eid']:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment