Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
421763fb
Unverified
Commit
421763fb
authored
Jun 23, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Jun 23, 2020
Browse files
[Bugfix] Fix #1641 (#1678)
* fix #1641 * lint
parent
fc7cd275
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
12 deletions
+34
-12
src/graph/unit_graph.cc
src/graph/unit_graph.cc
+5
-3
tests/compute/test_heterograph.py
tests/compute/test_heterograph.py
+29
-9
No files found.
src/graph/unit_graph.cc
View file @
421763fb
...
...
@@ -376,10 +376,12 @@ class UnitGraph::COO : public BaseHeteroGraph {
}
else
{
IdArray
new_src
=
aten
::
IndexSelect
(
adj_
.
row
,
eids
[
0
]);
IdArray
new_dst
=
aten
::
IndexSelect
(
adj_
.
col
,
eids
[
0
]);
subg
.
induced_vertices
.
emplace_back
(
aten
::
Range
(
0
,
NumVertices
(
0
),
NumBits
(),
Context
()));
subg
.
induced_vertices
.
emplace_back
(
aten
::
Range
(
0
,
NumVertices
(
1
),
NumBits
(),
Context
()));
subg
.
induced_vertices
.
emplace_back
(
aten
::
Range
(
0
,
NumVertices
(
SrcType
()),
NumBits
(),
Context
()));
subg
.
induced_vertices
.
emplace_back
(
aten
::
Range
(
0
,
NumVertices
(
DstType
()),
NumBits
(),
Context
()));
subg
.
graph
=
std
::
make_shared
<
COO
>
(
meta_graph
(),
NumVertices
(
0
),
NumVertices
(
1
),
new_src
,
new_dst
);
meta_graph
(),
NumVertices
(
SrcType
()
),
NumVertices
(
DstType
()
),
new_src
,
new_dst
);
subg
.
induced_edges
=
eids
;
}
return
subg
;
...
...
tests/compute/test_heterograph.py
View file @
421763fb
...
...
@@ -933,32 +933,52 @@ def test_subgraph(index_dtype):
sg2
=
g
.
edge_subgraph
({
'follows'
:
[
1
],
'plays'
:
[
1
],
'wishes'
:
[
1
]})
_check_subgraph
(
g
,
sg2
)
def
_check_subgraph_single_ntype
(
g
,
sg
):
def
_check_subgraph_single_ntype
(
g
,
sg
,
preserve_nodes
=
False
):
assert
sg
.
ntypes
==
g
.
ntypes
assert
sg
.
etypes
==
g
.
etypes
assert
sg
.
canonical_etypes
==
g
.
canonical_etypes
assert
F
.
array_equal
(
F
.
tensor
(
sg
.
nodes
[
'user'
].
data
[
dgl
.
NID
]),
F
.
tensor
([
1
,
2
],
F
.
int64
))
if
not
preserve_nodes
:
assert
F
.
array_equal
(
F
.
tensor
(
sg
.
nodes
[
'user'
].
data
[
dgl
.
NID
]),
F
.
tensor
([
1
,
2
],
F
.
int64
))
else
:
for
ntype
in
sg
.
ntypes
:
assert
g
.
number_of_nodes
(
ntype
)
==
sg
.
number_of_nodes
(
ntype
)
assert
F
.
array_equal
(
F
.
tensor
(
sg
.
edges
[
'follows'
].
data
[
dgl
.
EID
]),
F
.
tensor
([
1
],
F
.
int64
))
assert
F
.
array_equal
(
sg
.
nodes
[
'user'
].
data
[
'h'
],
g
.
nodes
[
'user'
].
data
[
'h'
][
1
:
3
])
if
not
preserve_nodes
:
assert
F
.
array_equal
(
sg
.
nodes
[
'user'
].
data
[
'h'
],
g
.
nodes
[
'user'
].
data
[
'h'
][
1
:
3
])
assert
F
.
array_equal
(
sg
.
edges
[
'follows'
].
data
[
'h'
],
g
.
edges
[
'follows'
].
data
[
'h'
][
1
:
2
])
def
_check_subgraph_single_etype
(
g
,
sg
):
def
_check_subgraph_single_etype
(
g
,
sg
,
preserve_nodes
=
False
):
assert
sg
.
ntypes
==
g
.
ntypes
assert
sg
.
etypes
==
g
.
etypes
assert
sg
.
canonical_etypes
==
g
.
canonical_etypes
assert
F
.
array_equal
(
F
.
tensor
(
sg
.
nodes
[
'user'
].
data
[
dgl
.
NID
]),
F
.
tensor
([
0
,
1
],
F
.
int64
))
assert
F
.
array_equal
(
F
.
tensor
(
sg
.
nodes
[
'game'
].
data
[
dgl
.
NID
]),
F
.
tensor
([
0
],
F
.
int64
))
if
not
preserve_nodes
:
assert
F
.
array_equal
(
F
.
tensor
(
sg
.
nodes
[
'user'
].
data
[
dgl
.
NID
]),
F
.
tensor
([
0
,
1
],
F
.
int64
))
assert
F
.
array_equal
(
F
.
tensor
(
sg
.
nodes
[
'game'
].
data
[
dgl
.
NID
]),
F
.
tensor
([
0
],
F
.
int64
))
else
:
for
ntype
in
sg
.
ntypes
:
assert
g
.
number_of_nodes
(
ntype
)
==
sg
.
number_of_nodes
(
ntype
)
assert
F
.
array_equal
(
F
.
tensor
(
sg
.
edges
[
'plays'
].
data
[
dgl
.
EID
]),
F
.
tensor
([
0
,
1
],
F
.
int64
))
sg1_graph
=
g_graph
.
subgraph
([
1
,
2
])
_check_subgraph_single_ntype
(
g_graph
,
sg1_graph
)
sg1_graph
=
g_graph
.
edge_subgraph
([
1
])
_check_subgraph_single_ntype
(
g_graph
,
sg1_graph
)
sg1_graph
=
g_graph
.
edge_subgraph
([
1
],
preserve_nodes
=
True
)
_check_subgraph_single_ntype
(
g_graph
,
sg1_graph
,
True
)
sg2_bipartite
=
g_bipartite
.
edge_subgraph
([
0
,
1
])
_check_subgraph_single_etype
(
g_bipartite
,
sg2_bipartite
)
sg2_bipartite
=
g_bipartite
.
edge_subgraph
([
0
,
1
],
preserve_nodes
=
True
)
_check_subgraph_single_etype
(
g_bipartite
,
sg2_bipartite
,
True
)
def
_check_typed_subgraph1
(
g
,
sg
):
assert
set
(
sg
.
ntypes
)
==
{
'user'
,
'game'
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment