Unverified Commit cfb24790 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[hotfix] Several bug fix for remove nodes/edges in DGLGraph. (#1521)

* upd

* upd

* better
parent 20ec7bb0
...@@ -5,6 +5,7 @@ from __future__ import absolute_import ...@@ -5,6 +5,7 @@ from __future__ import absolute_import
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import Iterable from typing import Iterable
from functools import wraps
import networkx as nx import networkx as nx
import dgl import dgl
...@@ -819,6 +820,7 @@ class DGLBaseGraph(object): ...@@ -819,6 +820,7 @@ class DGLBaseGraph(object):
def mutation(func): def mutation(func):
"""A decorator to decorate functions that might change graph structure.""" """A decorator to decorate functions that might change graph structure."""
@wraps(func)
def inner(g, *args, **kwargs): def inner(g, *args, **kwargs):
if g.is_readonly: if g.is_readonly:
raise DGLError("Readonly graph. Mutation is not allowed.") raise DGLError("Readonly graph. Mutation is not allowed.")
...@@ -1302,13 +1304,17 @@ class DGLGraph(DGLBaseGraph): ...@@ -1302,13 +1304,17 @@ class DGLGraph(DGLBaseGraph):
induced_nodes = utils.set_diff(utils.toindex(self.nodes()), utils.toindex(vids)) induced_nodes = utils.set_diff(utils.toindex(self.nodes()), utils.toindex(vids))
sgi = self._graph.node_subgraph(induced_nodes) sgi = self._graph.node_subgraph(induced_nodes)
num_nodes = len(sgi.induced_nodes)
num_edges = len(sgi.induced_edges)
if isinstance(self._node_frame, FrameRef): if isinstance(self._node_frame, FrameRef):
self._node_frame = FrameRef(Frame(self._node_frame[sgi.induced_nodes])) self._node_frame = FrameRef(Frame(self._node_frame[sgi.induced_nodes],
num_rows=num_nodes))
else: else:
self._node_frame = FrameRef(self._node_frame, sgi.induced_nodes) self._node_frame = FrameRef(self._node_frame, sgi.induced_nodes)
if isinstance(self._edge_frame, FrameRef): if isinstance(self._edge_frame, FrameRef):
self._edge_frame = FrameRef(Frame(self._edge_frame[sgi.induced_edges])) self._edge_frame = FrameRef(Frame(self._edge_frame[sgi.induced_edges],
num_rows=num_edges))
else: else:
self._edge_frame = FrameRef(self._edge_frame, sgi.induced_edges) self._edge_frame = FrameRef(self._edge_frame, sgi.induced_edges)
...@@ -1365,13 +1371,17 @@ class DGLGraph(DGLBaseGraph): ...@@ -1365,13 +1371,17 @@ class DGLGraph(DGLBaseGraph):
utils.toindex(range(self.number_of_edges())), utils.toindex(eids)) utils.toindex(range(self.number_of_edges())), utils.toindex(eids))
sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes=True) sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes=True)
num_nodes = len(sgi.induced_nodes)
num_edges = len(sgi.induced_edges)
if isinstance(self._node_frame, FrameRef): if isinstance(self._node_frame, FrameRef):
self._node_frame = FrameRef(Frame(self._node_frame[sgi.induced_nodes])) self._node_frame = FrameRef(Frame(self._node_frame[sgi.induced_nodes],
num_rows=num_nodes))
else: else:
self._node_frame = FrameRef(self._node_frame, sgi.induced_nodes) self._node_frame = FrameRef(self._node_frame, sgi.induced_nodes)
if isinstance(self._edge_frame, FrameRef): if isinstance(self._edge_frame, FrameRef):
self._edge_frame = FrameRef(Frame(self._edge_frame[sgi.induced_edges])) self._edge_frame = FrameRef(Frame(self._edge_frame[sgi.induced_edges],
num_rows=num_edges))
else: else:
self._edge_frame = FrameRef(self._edge_frame, sgi.induced_edges) self._edge_frame = FrameRef(self._edge_frame, sgi.induced_edges)
......
...@@ -163,6 +163,24 @@ def test_edge_frame(): ...@@ -163,6 +163,24 @@ def test_edge_frame():
g.remove_edges(range(3, 7)) g.remove_edges(range(3, 7))
assert F.allclose(g.edata['h'], F.zerocopy_from_numpy(new_data)) assert F.allclose(g.edata['h'], F.zerocopy_from_numpy(new_data))
def test_frame_size():
# reproduce https://github.com/dmlc/dgl/issues/1287.
# remove nodes
g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 2, 3, 1, 1], [1, 0, 3, 1, 0])
g.remove_nodes([0, 1])
assert g._node_frame.num_rows == 3
assert g._edge_frame.num_rows == 1
# remove edges
g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 2, 3, 1, 1], [1, 0, 3, 1, 0])
g.remove_edges([0, 1])
assert g._node_frame.num_rows == 5
assert g._edge_frame.num_rows == 3
if __name__ == '__main__': if __name__ == '__main__':
test_node_removal() test_node_removal()
test_edge_removal() test_edge_removal()
...@@ -171,3 +189,4 @@ if __name__ == '__main__': ...@@ -171,3 +189,4 @@ if __name__ == '__main__':
test_node_and_edge_removal() test_node_and_edge_removal()
test_node_frame() test_node_frame()
test_edge_frame() test_edge_frame()
test_frame_size()
...@@ -192,7 +192,6 @@ print(g_multi.edata['w']) ...@@ -192,7 +192,6 @@ print(g_multi.edata['w'])
############################################################################### ###############################################################################
# .. note:: # .. note::
# #
# * Nodes and edges can be added but not removed.
# * Updating a feature of different schemes raises the risk of error on individual nodes (or # * Updating a feature of different schemes raises the risk of error on individual nodes (or
# node subset). # node subset).
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment