Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
6694f7b9
Unverified
Commit
6694f7b9
authored
Nov 25, 2019
by
Mufei Li
Committed by
GitHub
Nov 25, 2019
Browse files
[Bug] Fix DGLHeteroGraph.edge_type_subgraph (#1040)
* Update * Try CI
parent
168fc2cf
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
9 deletions
+22
-9
python/dgl/heterograph.py
python/dgl/heterograph.py
+5
-5
tests/compute/test_heterograph.py
tests/compute/test_heterograph.py
+17
-4
No files found.
python/dgl/heterograph.py
View file @
6694f7b9
...
...
@@ -1752,12 +1752,12 @@ class DGLHeteroGraph(object):
rel_graphs
=
[
self
.
_graph
.
get_relation_graph
(
i
)
for
i
in
etype_ids
]
meta_src
=
meta_src
.
tonumpy
()
meta_dst
=
meta_dst
.
tonumpy
()
induced_
ntype_i
ds
=
list
(
set
(
meta_src
)
|
set
(
meta_dst
))
mapped_meta_src
=
[
induced_
ntype_i
ds
[
v
]
for
v
in
meta_src
]
mapped_meta_dst
=
[
induced_
ntype_i
ds
[
v
]
for
v
in
meta_dst
]
node_frames
=
[
self
.
_node_frames
[
i
]
for
i
in
induced_
ntype_i
ds
]
ntype
s
_i
nvmap
=
{
n
:
i
for
i
,
n
in
enumerate
(
set
(
meta_src
)
|
set
(
meta_dst
))
}
mapped_meta_src
=
[
ntype
s
_i
nvmap
[
v
]
for
v
in
meta_src
]
mapped_meta_dst
=
[
ntype
s
_i
nvmap
[
v
]
for
v
in
meta_dst
]
node_frames
=
[
self
.
_node_frames
[
i
]
for
i
in
ntype
s
_i
nvmap
]
edge_frames
=
[
self
.
_edge_frames
[
i
]
for
i
in
etype_ids
]
induced_ntypes
=
[
self
.
_ntypes
[
i
]
for
i
in
induced_
ntype_i
ds
]
induced_ntypes
=
[
self
.
_ntypes
[
i
]
for
i
in
ntype
s
_i
nvmap
]
induced_etypes
=
[
self
.
_etypes
[
i
]
for
i
in
etype_ids
]
# get the "name" of edge type
metagraph
=
graph_index
.
from_edge_list
((
mapped_meta_src
,
mapped_meta_dst
),
True
,
True
)
...
...
tests/compute/test_heterograph.py
View file @
6694f7b9
...
...
@@ -732,7 +732,7 @@ def test_subgraph():
sg2
=
g
.
edge_subgraph
({
'follows'
:
[
1
],
'plays'
:
[
1
],
'wishes'
:
[
1
]})
_check_subgraph
(
g
,
sg2
)
def
_check_typed_subgraph
(
g
,
sg
):
def
_check_typed_subgraph
1
(
g
,
sg
):
assert
set
(
sg
.
ntypes
)
==
{
'user'
,
'game'
}
assert
set
(
sg
.
etypes
)
==
{
'follows'
,
'plays'
,
'wishes'
}
for
ntype
in
sg
.
ntypes
:
...
...
@@ -749,10 +749,23 @@ def test_subgraph():
assert
F
.
array_equal
(
sg
.
nodes
[
'user'
].
data
[
'h'
],
g
.
nodes
[
'user'
].
data
[
'h'
])
assert
F
.
array_equal
(
sg
.
edges
[
'follows'
].
data
[
'h'
],
g
.
edges
[
'follows'
].
data
[
'h'
])
def
_check_typed_subgraph2
(
g
,
sg
):
assert
set
(
sg
.
ntypes
)
==
{
'developer'
,
'game'
}
assert
set
(
sg
.
etypes
)
==
{
'develops'
}
for
ntype
in
sg
.
ntypes
:
assert
sg
.
number_of_nodes
(
ntype
)
==
g
.
number_of_nodes
(
ntype
)
for
etype
in
sg
.
etypes
:
src_sg
,
dst_sg
=
sg
.
all_edges
(
etype
=
etype
,
order
=
'eid'
)
src_g
,
dst_g
=
g
.
all_edges
(
etype
=
etype
,
order
=
'eid'
)
assert
F
.
array_equal
(
src_sg
,
src_g
)
assert
F
.
array_equal
(
dst_sg
,
dst_g
)
sg3
=
g
.
node_type_subgraph
([
'user'
,
'game'
])
_check_typed_subgraph
(
g
,
sg3
)
sg4
=
g
.
edge_type_subgraph
([
'follows'
,
'plays'
,
'wishes'
])
_check_typed_subgraph
(
g
,
sg4
)
_check_typed_subgraph1
(
g
,
sg3
)
sg4
=
g
.
edge_type_subgraph
([
'develops'
])
_check_typed_subgraph2
(
g
,
sg4
)
sg5
=
g
.
edge_type_subgraph
([
'follows'
,
'plays'
,
'wishes'
])
_check_typed_subgraph1
(
g
,
sg5
)
def
test_apply
():
def
node_udf
(
nodes
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment