Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
cfb24790
Unverified
Commit
cfb24790
authored
May 15, 2020
by
Zihao Ye
Committed by
GitHub
May 15, 2020
Browse files
[hotfix] Several bug fix for remove nodes/edges in DGLGraph. (#1521)
* upd * upd * better
parent
20ec7bb0
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
5 deletions
+33
-5
python/dgl/graph.py
python/dgl/graph.py
+14
-4
tests/compute/test_removal.py
tests/compute/test_removal.py
+19
-0
tutorials/basics/2_basics.py
tutorials/basics/2_basics.py
+0
-1
No files found.
python/dgl/graph.py
View file @
cfb24790
...
...
@@ -5,6 +5,7 @@ from __future__ import absolute_import
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
typing
import
Iterable
from
functools
import
wraps
import
networkx
as
nx
import
dgl
...
...
@@ -819,6 +820,7 @@ class DGLBaseGraph(object):
def
mutation
(
func
):
"""A decorator to decorate functions that might change graph structure."""
@
wraps
(
func
)
def
inner
(
g
,
*
args
,
**
kwargs
):
if
g
.
is_readonly
:
raise
DGLError
(
"Readonly graph. Mutation is not allowed."
)
...
...
@@ -1302,13 +1304,17 @@ class DGLGraph(DGLBaseGraph):
induced_nodes
=
utils
.
set_diff
(
utils
.
toindex
(
self
.
nodes
()),
utils
.
toindex
(
vids
))
sgi
=
self
.
_graph
.
node_subgraph
(
induced_nodes
)
num_nodes
=
len
(
sgi
.
induced_nodes
)
num_edges
=
len
(
sgi
.
induced_edges
)
if
isinstance
(
self
.
_node_frame
,
FrameRef
):
self
.
_node_frame
=
FrameRef
(
Frame
(
self
.
_node_frame
[
sgi
.
induced_nodes
]))
self
.
_node_frame
=
FrameRef
(
Frame
(
self
.
_node_frame
[
sgi
.
induced_nodes
],
num_rows
=
num_nodes
))
else
:
self
.
_node_frame
=
FrameRef
(
self
.
_node_frame
,
sgi
.
induced_nodes
)
if
isinstance
(
self
.
_edge_frame
,
FrameRef
):
self
.
_edge_frame
=
FrameRef
(
Frame
(
self
.
_edge_frame
[
sgi
.
induced_edges
]))
self
.
_edge_frame
=
FrameRef
(
Frame
(
self
.
_edge_frame
[
sgi
.
induced_edges
],
num_rows
=
num_edges
))
else
:
self
.
_edge_frame
=
FrameRef
(
self
.
_edge_frame
,
sgi
.
induced_edges
)
...
...
@@ -1365,13 +1371,17 @@ class DGLGraph(DGLBaseGraph):
utils
.
toindex
(
range
(
self
.
number_of_edges
())),
utils
.
toindex
(
eids
))
sgi
=
self
.
_graph
.
edge_subgraph
(
induced_edges
,
preserve_nodes
=
True
)
num_nodes
=
len
(
sgi
.
induced_nodes
)
num_edges
=
len
(
sgi
.
induced_edges
)
if
isinstance
(
self
.
_node_frame
,
FrameRef
):
self
.
_node_frame
=
FrameRef
(
Frame
(
self
.
_node_frame
[
sgi
.
induced_nodes
]))
self
.
_node_frame
=
FrameRef
(
Frame
(
self
.
_node_frame
[
sgi
.
induced_nodes
],
num_rows
=
num_nodes
))
else
:
self
.
_node_frame
=
FrameRef
(
self
.
_node_frame
,
sgi
.
induced_nodes
)
if
isinstance
(
self
.
_edge_frame
,
FrameRef
):
self
.
_edge_frame
=
FrameRef
(
Frame
(
self
.
_edge_frame
[
sgi
.
induced_edges
]))
self
.
_edge_frame
=
FrameRef
(
Frame
(
self
.
_edge_frame
[
sgi
.
induced_edges
],
num_rows
=
num_edges
))
else
:
self
.
_edge_frame
=
FrameRef
(
self
.
_edge_frame
,
sgi
.
induced_edges
)
...
...
tests/compute/test_removal.py
View file @
cfb24790
...
...
@@ -163,6 +163,24 @@ def test_edge_frame():
g
.
remove_edges
(
range
(
3
,
7
))
assert
F
.
allclose
(
g
.
edata
[
'h'
],
F
.
zerocopy_from_numpy
(
new_data
))
def
test_frame_size
():
# reproduce https://github.com/dmlc/dgl/issues/1287.
# remove nodes
g
=
dgl
.
DGLGraph
()
g
.
add_nodes
(
5
)
g
.
add_edges
([
0
,
2
,
3
,
1
,
1
],
[
1
,
0
,
3
,
1
,
0
])
g
.
remove_nodes
([
0
,
1
])
assert
g
.
_node_frame
.
num_rows
==
3
assert
g
.
_edge_frame
.
num_rows
==
1
# remove edges
g
=
dgl
.
DGLGraph
()
g
.
add_nodes
(
5
)
g
.
add_edges
([
0
,
2
,
3
,
1
,
1
],
[
1
,
0
,
3
,
1
,
0
])
g
.
remove_edges
([
0
,
1
])
assert
g
.
_node_frame
.
num_rows
==
5
assert
g
.
_edge_frame
.
num_rows
==
3
if
__name__
==
'__main__'
:
test_node_removal
()
test_edge_removal
()
...
...
@@ -171,3 +189,4 @@ if __name__ == '__main__':
test_node_and_edge_removal
()
test_node_frame
()
test_edge_frame
()
test_frame_size
()
tutorials/basics/2_basics.py
View file @
cfb24790
...
...
@@ -192,7 +192,6 @@ print(g_multi.edata['w'])
###############################################################################
# .. note::
#
# * Nodes and edges can be added but not removed.
# * Updating a feature of different schemes raises the risk of error on individual nodes (or
# node subset).
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment