Unverified Commit cfb24790 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[hotfix] Several bug fix for remove nodes/edges in DGLGraph. (#1521)

* upd

* upd

* better
parent 20ec7bb0
......@@ -5,6 +5,7 @@ from __future__ import absolute_import
from collections import defaultdict
from contextlib import contextmanager
from typing import Iterable
from functools import wraps
import networkx as nx
import dgl
......@@ -819,6 +820,7 @@ class DGLBaseGraph(object):
def mutation(func):
"""A decorator to decorate functions that might change graph structure."""
@wraps(func)
def inner(g, *args, **kwargs):
if g.is_readonly:
raise DGLError("Readonly graph. Mutation is not allowed.")
......@@ -1302,13 +1304,17 @@ class DGLGraph(DGLBaseGraph):
induced_nodes = utils.set_diff(utils.toindex(self.nodes()), utils.toindex(vids))
sgi = self._graph.node_subgraph(induced_nodes)
num_nodes = len(sgi.induced_nodes)
num_edges = len(sgi.induced_edges)
if isinstance(self._node_frame, FrameRef):
self._node_frame = FrameRef(Frame(self._node_frame[sgi.induced_nodes]))
self._node_frame = FrameRef(Frame(self._node_frame[sgi.induced_nodes],
num_rows=num_nodes))
else:
self._node_frame = FrameRef(self._node_frame, sgi.induced_nodes)
if isinstance(self._edge_frame, FrameRef):
self._edge_frame = FrameRef(Frame(self._edge_frame[sgi.induced_edges]))
self._edge_frame = FrameRef(Frame(self._edge_frame[sgi.induced_edges],
num_rows=num_edges))
else:
self._edge_frame = FrameRef(self._edge_frame, sgi.induced_edges)
......@@ -1365,13 +1371,17 @@ class DGLGraph(DGLBaseGraph):
utils.toindex(range(self.number_of_edges())), utils.toindex(eids))
sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes=True)
num_nodes = len(sgi.induced_nodes)
num_edges = len(sgi.induced_edges)
if isinstance(self._node_frame, FrameRef):
self._node_frame = FrameRef(Frame(self._node_frame[sgi.induced_nodes]))
self._node_frame = FrameRef(Frame(self._node_frame[sgi.induced_nodes],
num_rows=num_nodes))
else:
self._node_frame = FrameRef(self._node_frame, sgi.induced_nodes)
if isinstance(self._edge_frame, FrameRef):
self._edge_frame = FrameRef(Frame(self._edge_frame[sgi.induced_edges]))
self._edge_frame = FrameRef(Frame(self._edge_frame[sgi.induced_edges],
num_rows=num_edges))
else:
self._edge_frame = FrameRef(self._edge_frame, sgi.induced_edges)
......
......@@ -163,6 +163,24 @@ def test_edge_frame():
g.remove_edges(range(3, 7))
assert F.allclose(g.edata['h'], F.zerocopy_from_numpy(new_data))
def test_frame_size():
# reproduce https://github.com/dmlc/dgl/issues/1287.
# remove nodes
g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 2, 3, 1, 1], [1, 0, 3, 1, 0])
g.remove_nodes([0, 1])
assert g._node_frame.num_rows == 3
assert g._edge_frame.num_rows == 1
# remove edges
g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 2, 3, 1, 1], [1, 0, 3, 1, 0])
g.remove_edges([0, 1])
assert g._node_frame.num_rows == 5
assert g._edge_frame.num_rows == 3
if __name__ == '__main__':
test_node_removal()
test_edge_removal()
......@@ -171,3 +189,4 @@ if __name__ == '__main__':
test_node_and_edge_removal()
test_node_frame()
test_edge_frame()
test_frame_size()
......@@ -192,7 +192,6 @@ print(g_multi.edata['w'])
###############################################################################
# .. note::
#
# * Nodes and edges can be added but not removed.
# * Updating a feature of different schemes raises the risk of error on individual nodes (or
# node subset).
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment