Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
3e72c53a
Unverified
Commit
3e72c53a
authored
Jun 15, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Jun 15, 2020
Browse files
[Bug?] Fix #1563 (#1642)
parent
ba2ee7bd
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
2 deletions
+47
-2
python/dgl/heterograph.py
python/dgl/heterograph.py
+17
-2
tests/compute/test_heterograph.py
tests/compute/test_heterograph.py
+30
-0
No files found.
python/dgl/heterograph.py
View file @
3e72c53a
"""Classes for heterogeneous graphs."""
"""Classes for heterogeneous graphs."""
#pylint: disable= too-many-lines
#pylint: disable= too-many-lines
from
collections
import
defaultdict
from
collections
import
defaultdict
from
collections.abc
import
Mapping
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
import
copy
import
copy
import
networkx
as
nx
import
networkx
as
nx
...
@@ -1892,10 +1893,13 @@ class DGLHeteroGraph(object):
...
@@ -1892,10 +1893,13 @@ class DGLHeteroGraph(object):
Parameters
Parameters
----------
----------
nodes : dict[str->list or iterable]
nodes :
list or
dict[str->list or iterable]
A dictionary mapping node types to node ID array for constructing
A dictionary mapping node types to node ID array for constructing
subgraph. All nodes must exist in the graph.
subgraph. All nodes must exist in the graph.
If the graph only has one node type, one can just specify a list,
tensor, or any iterable of node IDs intead.
Returns
Returns
-------
-------
G : DGLHeteroGraph
G : DGLHeteroGraph
...
@@ -1952,7 +1956,11 @@ class DGLHeteroGraph(object):
...
@@ -1952,7 +1956,11 @@ class DGLHeteroGraph(object):
--------
--------
edge_subgraph
edge_subgraph
"""
"""
check_same_dtype
(
self
.
_idtype_str
,
nodes
)
if
not
isinstance
(
nodes
,
Mapping
):
assert
len
(
self
.
ntypes
)
==
1
,
\
'need a dict of node type and IDs for graph with multiple node types'
nodes
=
{
self
.
ntypes
[
0
]:
nodes
}
check_idtype_dict
(
self
.
_idtype_str
,
nodes
)
induced_nodes
=
[
utils
.
toindex
(
nodes
.
get
(
ntype
,
[]),
self
.
_idtype_str
)
induced_nodes
=
[
utils
.
toindex
(
nodes
.
get
(
ntype
,
[]),
self
.
_idtype_str
)
for
ntype
in
self
.
ntypes
]
for
ntype
in
self
.
ntypes
]
sgi
=
self
.
_graph
.
node_subgraph
(
induced_nodes
)
sgi
=
self
.
_graph
.
node_subgraph
(
induced_nodes
)
...
@@ -1975,6 +1983,9 @@ class DGLHeteroGraph(object):
...
@@ -1975,6 +1983,9 @@ class DGLHeteroGraph(object):
The edge types are characterized by triplets of
The edge types are characterized by triplets of
``(src type, etype, dst type)``.
``(src type, etype, dst type)``.
If the graph only has one edge type, one can just specify a list,
tensor, or any iterable of edge IDs intead.
preserve_nodes : bool
preserve_nodes : bool
Whether to preserve all nodes or not. If false, all nodes
Whether to preserve all nodes or not. If false, all nodes
without edges will be removed. (Default: False)
without edges will be removed. (Default: False)
...
@@ -2035,6 +2046,10 @@ class DGLHeteroGraph(object):
...
@@ -2035,6 +2046,10 @@ class DGLHeteroGraph(object):
--------
--------
subgraph
subgraph
"""
"""
if
not
isinstance
(
edges
,
Mapping
):
assert
len
(
self
.
canonical_etypes
)
==
1
,
\
'need a dict of edge type and IDs for graph with multiple edge types'
edges
=
{
self
.
canonical_etypes
[
0
]:
edges
}
check_idtype_dict
(
self
.
_idtype_str
,
edges
)
check_idtype_dict
(
self
.
_idtype_str
,
edges
)
edges
=
{
self
.
to_canonical_etype
(
etype
):
e
for
etype
,
e
in
edges
.
items
()}
edges
=
{
self
.
to_canonical_etype
(
etype
):
e
for
etype
,
e
in
edges
.
items
()}
induced_edges
=
[
induced_edges
=
[
...
...
tests/compute/test_heterograph.py
View file @
3e72c53a
...
@@ -898,6 +898,9 @@ def test_transform(index_dtype):
...
@@ -898,6 +898,9 @@ def test_transform(index_dtype):
@
parametrize_dtype
@
parametrize_dtype
def
test_subgraph
(
index_dtype
):
def
test_subgraph
(
index_dtype
):
g
=
create_test_heterograph
(
index_dtype
)
g
=
create_test_heterograph
(
index_dtype
)
g_graph
=
g
[
'follows'
]
g_bipartite
=
g
[
'plays'
]
x
=
F
.
randn
((
3
,
5
))
x
=
F
.
randn
((
3
,
5
))
y
=
F
.
randn
((
2
,
4
))
y
=
F
.
randn
((
2
,
4
))
g
.
nodes
[
'user'
].
data
[
'h'
]
=
x
g
.
nodes
[
'user'
].
data
[
'h'
]
=
x
...
@@ -927,6 +930,33 @@ def test_subgraph(index_dtype):
...
@@ -927,6 +930,33 @@ def test_subgraph(index_dtype):
sg2
=
g
.
edge_subgraph
({
'follows'
:
[
1
],
'plays'
:
[
1
],
'wishes'
:
[
1
]})
sg2
=
g
.
edge_subgraph
({
'follows'
:
[
1
],
'plays'
:
[
1
],
'wishes'
:
[
1
]})
_check_subgraph
(
g
,
sg2
)
_check_subgraph
(
g
,
sg2
)
def
_check_subgraph_single_ntype
(
g
,
sg
):
assert
sg
.
ntypes
==
g
.
ntypes
assert
sg
.
etypes
==
g
.
etypes
assert
sg
.
canonical_etypes
==
g
.
canonical_etypes
assert
F
.
array_equal
(
F
.
tensor
(
sg
.
nodes
[
'user'
].
data
[
dgl
.
NID
]),
F
.
tensor
([
1
,
2
],
F
.
int64
))
assert
F
.
array_equal
(
F
.
tensor
(
sg
.
edges
[
'follows'
].
data
[
dgl
.
EID
]),
F
.
tensor
([
1
],
F
.
int64
))
assert
F
.
array_equal
(
sg
.
nodes
[
'user'
].
data
[
'h'
],
g
.
nodes
[
'user'
].
data
[
'h'
][
1
:
3
])
assert
F
.
array_equal
(
sg
.
edges
[
'follows'
].
data
[
'h'
],
g
.
edges
[
'follows'
].
data
[
'h'
][
1
:
2
])
def
_check_subgraph_single_etype
(
g
,
sg
):
assert
sg
.
ntypes
==
g
.
ntypes
assert
sg
.
etypes
==
g
.
etypes
assert
sg
.
canonical_etypes
==
g
.
canonical_etypes
assert
F
.
array_equal
(
F
.
tensor
(
sg
.
nodes
[
'user'
].
data
[
dgl
.
NID
]),
F
.
tensor
([
0
,
1
],
F
.
int64
))
assert
F
.
array_equal
(
F
.
tensor
(
sg
.
nodes
[
'game'
].
data
[
dgl
.
NID
]),
F
.
tensor
([
0
],
F
.
int64
))
assert
F
.
array_equal
(
F
.
tensor
(
sg
.
edges
[
'plays'
].
data
[
dgl
.
EID
]),
F
.
tensor
([
0
,
1
],
F
.
int64
))
sg1_graph
=
g_graph
.
subgraph
([
1
,
2
])
_check_subgraph_single_ntype
(
g_graph
,
sg1_graph
)
sg2_bipartite
=
g_bipartite
.
edge_subgraph
([
0
,
1
])
_check_subgraph_single_etype
(
g_bipartite
,
sg2_bipartite
)
def
_check_typed_subgraph1
(
g
,
sg
):
def
_check_typed_subgraph1
(
g
,
sg
):
assert
set
(
sg
.
ntypes
)
==
{
'user'
,
'game'
}
assert
set
(
sg
.
ntypes
)
==
{
'user'
,
'game'
}
assert
set
(
sg
.
etypes
)
==
{
'follows'
,
'plays'
,
'wishes'
}
assert
set
(
sg
.
etypes
)
==
{
'follows'
,
'plays'
,
'wishes'
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment