Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
3e72c53a
Unverified
Commit
3e72c53a
authored
Jun 15, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Jun 15, 2020
Browse files
[Bug?] Fix #1563 (#1642)
parent
ba2ee7bd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
2 deletions
+47
-2
python/dgl/heterograph.py
python/dgl/heterograph.py
+17
-2
tests/compute/test_heterograph.py
tests/compute/test_heterograph.py
+30
-0
No files found.
python/dgl/heterograph.py
View file @
3e72c53a
"""Classes for heterogeneous graphs."""
#pylint: disable= too-many-lines
from
collections
import
defaultdict
from
collections.abc
import
Mapping
from
contextlib
import
contextmanager
import
copy
import
networkx
as
nx
...
...
@@ -1892,10 +1893,13 @@ class DGLHeteroGraph(object):
Parameters
----------
nodes : dict[str->list or iterable]
nodes :
list or
dict[str->list or iterable]
A dictionary mapping node types to node ID array for constructing
subgraph. All nodes must exist in the graph.
If the graph only has one node type, one can just specify a list,
tensor, or any iterable of node IDs intead.
Returns
-------
G : DGLHeteroGraph
...
...
@@ -1952,7 +1956,11 @@ class DGLHeteroGraph(object):
--------
edge_subgraph
"""
check_same_dtype
(
self
.
_idtype_str
,
nodes
)
if
not
isinstance
(
nodes
,
Mapping
):
assert
len
(
self
.
ntypes
)
==
1
,
\
'need a dict of node type and IDs for graph with multiple node types'
nodes
=
{
self
.
ntypes
[
0
]:
nodes
}
check_idtype_dict
(
self
.
_idtype_str
,
nodes
)
induced_nodes
=
[
utils
.
toindex
(
nodes
.
get
(
ntype
,
[]),
self
.
_idtype_str
)
for
ntype
in
self
.
ntypes
]
sgi
=
self
.
_graph
.
node_subgraph
(
induced_nodes
)
...
...
@@ -1975,6 +1983,9 @@ class DGLHeteroGraph(object):
The edge types are characterized by triplets of
``(src type, etype, dst type)``.
If the graph only has one edge type, one can just specify a list,
tensor, or any iterable of edge IDs intead.
preserve_nodes : bool
Whether to preserve all nodes or not. If false, all nodes
without edges will be removed. (Default: False)
...
...
@@ -2035,6 +2046,10 @@ class DGLHeteroGraph(object):
--------
subgraph
"""
if
not
isinstance
(
edges
,
Mapping
):
assert
len
(
self
.
canonical_etypes
)
==
1
,
\
'need a dict of edge type and IDs for graph with multiple edge types'
edges
=
{
self
.
canonical_etypes
[
0
]:
edges
}
check_idtype_dict
(
self
.
_idtype_str
,
edges
)
edges
=
{
self
.
to_canonical_etype
(
etype
):
e
for
etype
,
e
in
edges
.
items
()}
induced_edges
=
[
...
...
tests/compute/test_heterograph.py
View file @
3e72c53a
...
...
@@ -898,6 +898,9 @@ def test_transform(index_dtype):
@
parametrize_dtype
def
test_subgraph
(
index_dtype
):
g
=
create_test_heterograph
(
index_dtype
)
g_graph
=
g
[
'follows'
]
g_bipartite
=
g
[
'plays'
]
x
=
F
.
randn
((
3
,
5
))
y
=
F
.
randn
((
2
,
4
))
g
.
nodes
[
'user'
].
data
[
'h'
]
=
x
...
...
@@ -927,6 +930,33 @@ def test_subgraph(index_dtype):
sg2
=
g
.
edge_subgraph
({
'follows'
:
[
1
],
'plays'
:
[
1
],
'wishes'
:
[
1
]})
_check_subgraph
(
g
,
sg2
)
def
_check_subgraph_single_ntype
(
g
,
sg
):
assert
sg
.
ntypes
==
g
.
ntypes
assert
sg
.
etypes
==
g
.
etypes
assert
sg
.
canonical_etypes
==
g
.
canonical_etypes
assert
F
.
array_equal
(
F
.
tensor
(
sg
.
nodes
[
'user'
].
data
[
dgl
.
NID
]),
F
.
tensor
([
1
,
2
],
F
.
int64
))
assert
F
.
array_equal
(
F
.
tensor
(
sg
.
edges
[
'follows'
].
data
[
dgl
.
EID
]),
F
.
tensor
([
1
],
F
.
int64
))
assert
F
.
array_equal
(
sg
.
nodes
[
'user'
].
data
[
'h'
],
g
.
nodes
[
'user'
].
data
[
'h'
][
1
:
3
])
assert
F
.
array_equal
(
sg
.
edges
[
'follows'
].
data
[
'h'
],
g
.
edges
[
'follows'
].
data
[
'h'
][
1
:
2
])
def
_check_subgraph_single_etype
(
g
,
sg
):
assert
sg
.
ntypes
==
g
.
ntypes
assert
sg
.
etypes
==
g
.
etypes
assert
sg
.
canonical_etypes
==
g
.
canonical_etypes
assert
F
.
array_equal
(
F
.
tensor
(
sg
.
nodes
[
'user'
].
data
[
dgl
.
NID
]),
F
.
tensor
([
0
,
1
],
F
.
int64
))
assert
F
.
array_equal
(
F
.
tensor
(
sg
.
nodes
[
'game'
].
data
[
dgl
.
NID
]),
F
.
tensor
([
0
],
F
.
int64
))
assert
F
.
array_equal
(
F
.
tensor
(
sg
.
edges
[
'plays'
].
data
[
dgl
.
EID
]),
F
.
tensor
([
0
,
1
],
F
.
int64
))
sg1_graph
=
g_graph
.
subgraph
([
1
,
2
])
_check_subgraph_single_ntype
(
g_graph
,
sg1_graph
)
sg2_bipartite
=
g_bipartite
.
edge_subgraph
([
0
,
1
])
_check_subgraph_single_etype
(
g_bipartite
,
sg2_bipartite
)
def
_check_typed_subgraph1
(
g
,
sg
):
assert
set
(
sg
.
ntypes
)
==
{
'user'
,
'game'
}
assert
set
(
sg
.
etypes
)
==
{
'follows'
,
'plays'
,
'wishes'
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment