Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
72f63455
Commit
72f63455
authored
Oct 05, 2018
by
Minjie Wang
Browse files
Fix for subgraph test and some docs
parent
b0e02e5b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
108 additions
and
27 deletions
+108
-27
python/dgl/graph.py
python/dgl/graph.py
+3
-1
python/dgl/subgraph.py
python/dgl/subgraph.py
+95
-14
tests/pytorch/test_subgraph.py
tests/pytorch/test_subgraph.py
+10
-12
No files found.
python/dgl/graph.py
View file @
72f63455
...
@@ -1179,7 +1179,9 @@ class DGLGraph(object):
...
@@ -1179,7 +1179,9 @@ class DGLGraph(object):
G : DGLSubGraph
G : DGLSubGraph
The subgraph.
The subgraph.
"""
"""
return
dgl
.
DGLSubGraph
(
self
,
nodes
)
induced_nodes
=
utils
.
toindex
(
nodes
)
gi
,
induced_edges
=
self
.
_graph
.
node_subgraph
(
induced_nodes
)
return
dgl
.
DGLSubGraph
(
self
,
induced_nodes
,
induced_edges
,
gi
)
def
merge
(
self
,
subgraphs
,
reduce_func
=
'sum'
):
def
merge
(
self
,
subgraphs
,
reduce_func
=
'sum'
):
"""Merge subgraph features back to this parent graph.
"""Merge subgraph features back to this parent graph.
...
...
python/dgl/subgraph.py
View file @
72f63455
"""
DGLS
ub
G
raph"""
"""
Class for s
ub
g
raph
data structure.
"""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
import
networkx
as
nx
import
networkx
as
nx
...
@@ -9,20 +9,99 @@ from .graph import DGLGraph
...
@@ -9,20 +9,99 @@ from .graph import DGLGraph
from
.
import
utils
from
.
import
utils
class
DGLSubGraph
(
DGLGraph
):
class
DGLSubGraph
(
DGLGraph
):
# TODO(gaiyu): ReadOnlyGraph
"""The subgraph class.
def
__init__
(
self
,
parent
,
There are two subgraph modes: shared and non-shared.
nodes
):
super
(
DGLSubGraph
,
self
).
__init__
()
For the "non-shared" mode, the user needs to explicitly call
# relabel nodes
``copy_from_parent`` to copy node/edge features from its parent graph.
* If the user tries to get node/edge features before ``copy_from_parent``,
s/he will get nothing.
* If the subgraph already has its own node/edge features, ``copy_from_parent``
will override them.
* Any update on the subgraph's node/edge features will not be seen
by the parent graph. As such, the memory consumption is of the order
of the subgraph size.
* To write the subgraph's node/edge features back to parent graph. There are two options:
(1) Use ``copy_to_parent`` API to write node/edge features back.
(2) [TODO] Use ``dgl.merge`` to merge multiple subgraphs back to one parent.
The "shared" mode is currently not supported.
The subgraph is read-only so mutation is not allowed.
Parameters
----------
parent : DGLGraph
The parent graph
parent_nid : utils.Index
The induced parent node ids in this subgraph.
parent_eid : utils.Index
The induced parent edge ids in this subgraph.
graph_idx : GraphIndex
The graph index.
shared : bool, optional
Whether the subgraph shares node/edge features with the parent graph.
"""
def
__init__
(
self
,
parent
,
parent_nid
,
parent_eid
,
graph_idx
,
shared
=
False
):
super
(
DGLSubGraph
,
self
).
__init__
(
graph_data
=
graph_idx
)
self
.
_parent
=
parent
self
.
_parent
=
parent
self
.
_parent_nid
=
utils
.
toindex
(
nodes
)
self
.
_parent_nid
=
parent_nid
self
.
_graph
,
self
.
_parent_eid
=
parent
.
_graph
.
node_subgraph
(
self
.
_parent_nid
)
self
.
_parent_eid
=
parent_eid
self
.
reset_messages
()
# override APIs
def
add_nodes
(
self
,
num
,
reprs
=
None
):
"""Add nodes. Disabled because BatchedDGLGraph is read-only."""
raise
RuntimeError
(
'Readonly graph. Mutation is not allowed.'
)
def
add_edge
(
self
,
u
,
v
,
reprs
=
None
):
"""Add one edge. Disabled because BatchedDGLGraph is read-only."""
raise
RuntimeError
(
'Readonly graph. Mutation is not allowed.'
)
def
add_edges
(
self
,
u
,
v
,
reprs
=
None
):
"""Add many edges. Disabled because BatchedDGLGraph is read-only."""
raise
RuntimeError
(
'Readonly graph. Mutation is not allowed.'
)
@
property
def
parent_nid
(
self
):
"""Get the parent node ids.
The returned tensor can be used as a map from the node id
in this subgraph to the node id in the parent graph.
Returns
-------
Tensor
The parent node id array.
"""
return
self
.
_parent_nid
.
tousertensor
()
@
property
def
parent_eid
(
self
):
"""Get the parent edge ids.
The returned tensor can be used as a map from the edge id
in this subgraph to the edge id in the parent graph.
Returns
-------
Tensor
The parent edge id array.
"""
return
self
.
_parent_eid
.
tousertensor
()
def
copy_to_parent
(
self
,
inplace
=
False
):
def
copy_to_parent
(
self
,
inplace
=
False
):
self
.
_parent
.
_node_frame
.
update_rows
(
self
.
_parent_nid
,
self
.
_node_frame
,
inplace
=
inplace
)
"""Write node/edge features to the parent graph.
self
.
_parent
.
_edge_frame
.
update_rows
(
self
.
_parent_eid
,
self
.
_edge_frame
,
inplace
=
inplace
)
Parameters
----------
inplace : bool
If true, use inplace write (no gradient but faster)
"""
self
.
_parent
.
_node_frame
.
update_rows
(
self
.
_parent_nid
,
self
.
_node_frame
,
inplace
=
inplace
)
self
.
_parent
.
_edge_frame
.
update_rows
(
self
.
_parent_eid
,
self
.
_edge_frame
,
inplace
=
inplace
)
def
copy_from_parent
(
self
):
def
copy_from_parent
(
self
):
"""Copy node/edge features from the parent graph.
"""Copy node/edge features from the parent graph.
...
@@ -30,6 +109,8 @@ class DGLSubGraph(DGLGraph):
...
@@ -30,6 +109,8 @@ class DGLSubGraph(DGLGraph):
All old features will be removed.
All old features will be removed.
"""
"""
if
self
.
_parent
.
_node_frame
.
num_rows
!=
0
:
if
self
.
_parent
.
_node_frame
.
num_rows
!=
0
:
self
.
_node_frame
=
FrameRef
(
Frame
(
self
.
_parent
.
_node_frame
[
self
.
_parent_nid
]))
self
.
_node_frame
=
FrameRef
(
Frame
(
self
.
_parent
.
_node_frame
[
self
.
_parent_nid
]))
if
self
.
_parent
.
_edge_frame
.
num_rows
!=
0
:
if
self
.
_parent
.
_edge_frame
.
num_rows
!=
0
:
self
.
_edge_frame
=
FrameRef
(
Frame
(
self
.
_parent
.
_edge_frame
[
self
.
_parent_eid
]))
self
.
_edge_frame
=
FrameRef
(
Frame
(
self
.
_parent
.
_edge_frame
[
self
.
_parent_eid
]))
tests/pytorch/test_subgraph.py
View file @
72f63455
...
@@ -5,13 +5,9 @@ from dgl.graph import DGLGraph
...
@@ -5,13 +5,9 @@ from dgl.graph import DGLGraph
D
=
5
D
=
5
def
check_eq
(
a
,
b
):
return
a
.
shape
==
b
.
shape
and
np
.
allclose
(
a
.
numpy
(),
b
.
numpy
())
def
generate_graph
(
grad
=
False
):
def
generate_graph
(
grad
=
False
):
g
=
DGLGraph
()
g
=
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_nodes
(
10
)
g
.
add_node
(
i
)
# 10 nodes.
# create a graph where 0 is the source and 9 is the sink
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
(
0
,
i
)
...
@@ -29,8 +25,10 @@ def test_basics():
...
@@ -29,8 +25,10 @@ def test_basics():
h
=
g
.
get_n_repr
()[
'h'
]
h
=
g
.
get_n_repr
()[
'h'
]
l
=
g
.
get_e_repr
()[
'l'
]
l
=
g
.
get_e_repr
()[
'l'
]
nid
=
[
0
,
2
,
3
,
6
,
7
,
9
]
nid
=
[
0
,
2
,
3
,
6
,
7
,
9
]
eid
=
[
2
,
3
,
4
,
5
,
10
,
11
,
12
,
13
,
16
]
sg
=
g
.
subgraph
(
nid
)
sg
=
g
.
subgraph
(
nid
)
eid
=
{
2
,
3
,
4
,
5
,
10
,
11
,
12
,
13
,
16
}
assert
set
(
sg
.
parent_eid
.
numpy
())
==
eid
eid
=
sg
.
parent_eid
# the subgraph is empty initially
# the subgraph is empty initially
assert
len
(
sg
.
get_n_repr
())
==
0
assert
len
(
sg
.
get_n_repr
())
==
0
assert
len
(
sg
.
get_e_repr
())
==
0
assert
len
(
sg
.
get_e_repr
())
==
0
...
@@ -39,7 +37,7 @@ def test_basics():
...
@@ -39,7 +37,7 @@ def test_basics():
assert
len
(
sg
.
get_n_repr
())
==
1
assert
len
(
sg
.
get_n_repr
())
==
1
assert
len
(
sg
.
get_e_repr
())
==
1
assert
len
(
sg
.
get_e_repr
())
==
1
sh
=
sg
.
get_n_repr
()[
'h'
]
sh
=
sg
.
get_n_repr
()[
'h'
]
assert
check_eq
(
h
[
nid
],
sh
)
assert
th
.
allclose
(
h
[
nid
],
sh
)
'''
'''
s, d, eid
s, d, eid
0, 1, 0
0, 1, 0
...
@@ -60,11 +58,11 @@ def test_basics():
...
@@ -60,11 +58,11 @@ def test_basics():
8, 9, 15 3
8, 9, 15 3
9, 0, 16 1
9, 0, 16 1
'''
'''
assert
check_eq
(
l
[
eid
],
sg
.
get_e_repr
()[
'l'
])
assert
th
.
allclose
(
l
[
eid
],
sg
.
get_e_repr
()[
'l'
])
# update the node/edge features on the subgraph should NOT
# update the node/edge features on the subgraph should NOT
# reflect to the parent graph.
# reflect to the parent graph.
sg
.
set_n_repr
({
'h'
:
th
.
zeros
((
6
,
D
))})
sg
.
set_n_repr
({
'h'
:
th
.
zeros
((
6
,
D
))})
assert
check_eq
(
h
,
g
.
get_n_repr
()[
'h'
])
assert
th
.
allclose
(
h
,
g
.
get_n_repr
()[
'h'
])
def
test_merge
():
def
test_merge
():
g
=
generate_graph
()
g
=
generate_graph
()
...
@@ -85,10 +83,10 @@ def test_merge():
...
@@ -85,10 +83,10 @@ def test_merge():
h
=
g
.
get_n_repr
()[
'h'
][:,
0
]
h
=
g
.
get_n_repr
()[
'h'
][:,
0
]
l
=
g
.
get_e_repr
()[
'l'
][:,
0
]
l
=
g
.
get_e_repr
()[
'l'
][:,
0
]
assert
check_eq
(
h
,
th
.
tensor
([
3.
,
0.
,
3.
,
3.
,
2.
,
0.
,
1.
,
1.
,
0.
,
1.
]))
assert
th
.
allclose
(
h
,
th
.
tensor
([
3.
,
0.
,
3.
,
3.
,
2.
,
0.
,
1.
,
1.
,
0.
,
1.
]))
assert
check_eq
(
l
,
assert
th
.
allclose
(
l
,
th
.
tensor
([
0.
,
0.
,
1.
,
1.
,
1.
,
1.
,
0.
,
0.
,
0.
,
3.
,
1.
,
4.
,
1.
,
4.
,
0.
,
3.
,
1.
]))
th
.
tensor
([
0.
,
0.
,
1.
,
1.
,
1.
,
1.
,
0.
,
0.
,
0.
,
3.
,
1.
,
4.
,
1.
,
4.
,
0.
,
3.
,
1.
]))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_basics
()
test_basics
()
test_merge
()
#
test_merge()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment