Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
43ba94ee
Unverified
Commit
43ba94ee
authored
Aug 03, 2022
by
Rhett Ying
Committed by
GitHub
Aug 03, 2022
Browse files
[BugFix] fix etype check in DistGraph.edge_subgraph (#4322)
parent
463650a7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
1 deletion
+14
-1
python/dgl/distributed/dist_graph.py
python/dgl/distributed/dist_graph.py
+1
-1
tests/distributed/test_dist_graph_store.py
tests/distributed/test_dist_graph_store.py
+13
-0
No files found.
python/dgl/distributed/dist_graph.py
View file @
43ba94ee
...
...
@@ -1180,7 +1180,7 @@ class DistGraph:
if
isinstance
(
edges
,
dict
):
# TODO(zhengda) we need to directly generate subgraph of all relations with
# one invocation.
if
isinstance
(
edges
,
tuple
):
if
isinstance
(
list
(
edges
.
keys
())[
0
]
,
tuple
):
subg
=
{
etype
:
self
.
find_edges
(
edges
[
etype
],
etype
[
1
])
for
etype
in
edges
}
else
:
subg
=
{}
...
...
tests/distributed/test_dist_graph_store.py
View file @
43ba94ee
...
...
@@ -228,6 +228,11 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges):
feats
=
F
.
squeeze
(
feats1
,
1
)
assert
np
.
all
(
F
.
asnumpy
(
feats
==
eids
))
# Test edge_subgraph
sg
=
g
.
edge_subgraph
(
eids
)
assert
sg
.
num_edges
()
==
len
(
eids
)
assert
F
.
array_equal
(
sg
.
edata
[
dgl
.
EID
],
eids
)
# Test init node data
new_shape
=
(
g
.
number_of_nodes
(),
2
)
test1
=
dgl
.
distributed
.
DistTensor
(
new_shape
,
F
.
int32
)
...
...
@@ -494,6 +499,14 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
feats
=
F
.
squeeze
(
feats1
,
1
)
assert
np
.
all
(
F
.
asnumpy
(
feats
==
eids
))
# Test edge_subgraph
sg
=
g
.
edge_subgraph
({
'r1'
:
eids
})
assert
sg
.
num_edges
()
==
len
(
eids
)
assert
F
.
array_equal
(
sg
.
edata
[
dgl
.
EID
],
eids
)
sg
=
g
.
edge_subgraph
({(
'n1'
,
'r1'
,
'n2'
):
eids
})
assert
sg
.
num_edges
()
==
len
(
eids
)
assert
F
.
array_equal
(
sg
.
edata
[
dgl
.
EID
],
eids
)
# Test init node data
new_shape
=
(
g
.
number_of_nodes
(
'n1'
),
2
)
g
.
nodes
[
'n1'
].
data
[
'test1'
]
=
dgl
.
distributed
.
DistTensor
(
new_shape
,
F
.
int32
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment