"vscode:/vscode.git/clone" did not exist on "50cbb0ecfecef497bc4fb08816d4217d1c84f593"
Unverified Commit b9c65e91 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[BUG] Another fix on remove edges when all edges are removed (#1386)

* [BUG] Another fix on remove edges when all edges are removed

* fix mxnet
parent d27b4859
......@@ -335,6 +335,9 @@ def sort_1d(input):
return val, idx
def arange(start, stop):
if start >= stop:
return nd.array([], dtype=np.int64)
else:
return nd.arange(start, stop, dtype=np.int64)
def rand_shuffle(arr):
......
......@@ -931,9 +931,12 @@ def remove_edges(g, edge_ids):
for i, canonical_etype in enumerate(g.canonical_etypes):
data = induced_eids_nd[i].data
if len(data) == 0:
# Empty means that no edges are removed and edges are not shuffled.
# Empty means that either
# (1) no edges are removed and edges are not shuffled.
# (2) all edges are removed.
# The following statement deals with both cases.
new_graph.edges[canonical_etype].data[EID] = F.arange(
0, g.number_of_edges(canonical_etype))
0, new_graph.number_of_edges(canonical_etype))
else:
new_graph.edges[canonical_etype].data[EID] = F.zerocopy_from_dgl_ndarray(data)
......
......@@ -547,9 +547,9 @@ def test_remove_edges():
check(g3, 'AB', g, [3])
check(g3, 'BA', g, [1])
g4 = dgl.remove_edges(g, {'AB': F.tensor([3])})
g4 = dgl.remove_edges(g, {'AB': F.tensor([3, 1, 2, 0])})
check(g4, 'AA', g, [])
check(g4, 'AB', g, [3])
check(g4, 'AB', g, [3, 1, 2, 0])
check(g4, 'BA', g, [])
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment