Unverified Commit b9c65e91 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[BUG] Another fix on remove edges when all edges are removed (#1386)

* [BUG] Another fix on remove edges when all edges are removed

* fix mxnet
parent d27b4859
...@@ -335,7 +335,10 @@ def sort_1d(input): ...@@ -335,7 +335,10 @@ def sort_1d(input):
return val, idx return val, idx
def arange(start, stop): def arange(start, stop):
return nd.arange(start, stop, dtype=np.int64) if start >= stop:
return nd.array([], dtype=np.int64)
else:
return nd.arange(start, stop, dtype=np.int64)
def rand_shuffle(arr): def rand_shuffle(arr):
return mx.nd.random.shuffle(arr) return mx.nd.random.shuffle(arr)
......
...@@ -931,9 +931,12 @@ def remove_edges(g, edge_ids): ...@@ -931,9 +931,12 @@ def remove_edges(g, edge_ids):
for i, canonical_etype in enumerate(g.canonical_etypes): for i, canonical_etype in enumerate(g.canonical_etypes):
data = induced_eids_nd[i].data data = induced_eids_nd[i].data
if len(data) == 0: if len(data) == 0:
# Empty means that no edges are removed and edges are not shuffled. # Empty means that either
# (1) no edges are removed and edges are not shuffled.
# (2) all edges are removed.
# The following statement deals with both cases.
new_graph.edges[canonical_etype].data[EID] = F.arange( new_graph.edges[canonical_etype].data[EID] = F.arange(
0, g.number_of_edges(canonical_etype)) 0, new_graph.number_of_edges(canonical_etype))
else: else:
new_graph.edges[canonical_etype].data[EID] = F.zerocopy_from_dgl_ndarray(data) new_graph.edges[canonical_etype].data[EID] = F.zerocopy_from_dgl_ndarray(data)
......
...@@ -547,9 +547,9 @@ def test_remove_edges(): ...@@ -547,9 +547,9 @@ def test_remove_edges():
check(g3, 'AB', g, [3]) check(g3, 'AB', g, [3])
check(g3, 'BA', g, [1]) check(g3, 'BA', g, [1])
g4 = dgl.remove_edges(g, {'AB': F.tensor([3])}) g4 = dgl.remove_edges(g, {'AB': F.tensor([3, 1, 2, 0])})
check(g4, 'AA', g, []) check(g4, 'AA', g, [])
check(g4, 'AB', g, [3]) check(g4, 'AB', g, [3, 1, 2, 0])
check(g4, 'BA', g, []) check(g4, 'BA', g, [])
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment