Unverified Commit d27b4859 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[BUG] Fix remove_edges crashing with empty edge ID tensors (#1384)

parent a0721405
...@@ -922,15 +922,20 @@ def remove_edges(g, edge_ids): ...@@ -922,15 +922,20 @@ def remove_edges(g, edge_ids):
"Graph has more than one edge type; specify a dict for edge_id instead.") "Graph has more than one edge type; specify a dict for edge_id instead.")
edge_ids = {g.canonical_etypes[0]: edge_ids} edge_ids = {g.canonical_etypes[0]: edge_ids}
edge_ids_nd = [None] * len(g.etypes) edge_ids_nd = [nd.null()] * len(g.etypes)
for key, value in edge_ids.items(): for key, value in edge_ids.items():
edge_ids_nd[g.get_etype_id(key)] = F.zerocopy_to_dgl_ndarray(value) edge_ids_nd[g.get_etype_id(key)] = F.zerocopy_to_dgl_ndarray(value)
new_graph_index, induced_eids_nd = _CAPI_DGLRemoveEdges(g._graph, edge_ids_nd) new_graph_index, induced_eids_nd = _CAPI_DGLRemoveEdges(g._graph, edge_ids_nd)
new_graph = DGLHeteroGraph(new_graph_index, g.ntypes, g.etypes) new_graph = DGLHeteroGraph(new_graph_index, g.ntypes, g.etypes)
for i, canonical_etype in enumerate(g.canonical_etypes): for i, canonical_etype in enumerate(g.canonical_etypes):
new_graph.edges[canonical_etype].data[EID] = F.zerocopy_from_dgl_ndarray( data = induced_eids_nd[i].data
induced_eids_nd[i].data) if len(data) == 0:
# Empty means that no edges are removed and edges are not shuffled.
new_graph.edges[canonical_etype].data[EID] = F.arange(
0, g.number_of_edges(canonical_etype))
else:
new_graph.edges[canonical_etype].data[EID] = F.zerocopy_from_dgl_ndarray(data)
return new_graph return new_graph
......
...@@ -542,6 +542,16 @@ def test_remove_edges(): ...@@ -542,6 +542,16 @@ def test_remove_edges():
check(g2, 'AB', g, [3]) check(g2, 'AB', g, [3])
check(g2, 'BA', g, [1]) check(g2, 'BA', g, [1])
g3 = dgl.remove_edges(g, {'AA': F.tensor([]), 'AB': F.tensor([3]), 'BA': F.tensor([1])})
check(g3, 'AA', g, [])
check(g3, 'AB', g, [3])
check(g3, 'BA', g, [1])
g4 = dgl.remove_edges(g, {'AB': F.tensor([3])})
check(g4, 'AA', g, [])
check(g4, 'AB', g, [3])
check(g4, 'BA', g, [])
if __name__ == '__main__': if __name__ == '__main__':
test_line_graph() test_line_graph()
test_no_backtracking() test_no_backtracking()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment