Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
4dd16f5d
Unverified
Commit
4dd16f5d
authored
Aug 01, 2022
by
Rhett Ying
Committed by
GitHub
Aug 01, 2022
Browse files
[BugFix] enable DistGraph.find_edge() works with str or tuple of str (#4319)
parent
44b68641
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
9 deletions
+19
-9
python/dgl/distributed/dist_graph.py
python/dgl/distributed/dist_graph.py
+2
-1
tests/distributed/test_distributed_sampling.py
tests/distributed/test_distributed_sampling.py
+17
-8
No files found.
python/dgl/distributed/dist_graph.py
View file @
4dd16f5d
...
...
@@ -1132,7 +1132,8 @@ class DistGraph:
gpb
=
self
.
get_partition_book
()
if
len
(
gpb
.
etypes
)
>
1
:
# if etype is a canonical edge type (str, str, str), extract the edge type
if
len
(
etype
)
==
3
:
if
isinstance
(
etype
,
tuple
):
assert
len
(
etype
)
==
3
,
'Invalid canonical etype: {}'
.
format
(
etype
)
etype
=
etype
[
1
]
edges
=
gpb
.
map_to_homo_eid
(
edges
,
etype
)
src
,
dst
=
dist_find_edges
(
self
,
edges
)
...
...
tests/distributed/test_distributed_sampling.py
View file @
4dd16f5d
...
...
@@ -160,9 +160,9 @@ def check_rpc_find_edges_shuffle(tmpdir, num_server):
def
create_random_hetero
(
dense
=
False
,
empty
=
False
):
num_nodes
=
{
'n1'
:
210
,
'n2'
:
200
,
'n3'
:
220
}
if
dense
else
\
{
'n1'
:
1010
,
'n2'
:
1000
,
'n3'
:
1020
}
etypes
=
[(
'n1'
,
'r1'
,
'n2'
),
(
'n1'
,
'r
2
'
,
'n3'
),
(
'n2'
,
'r3'
,
'n3'
)]
etypes
=
[(
'n1'
,
'r1
2
'
,
'n2'
),
(
'n1'
,
'r
13
'
,
'n3'
),
(
'n2'
,
'r
2
3'
,
'n3'
)]
edges
=
{}
random
.
seed
(
42
)
for
etype
in
etypes
:
...
...
@@ -195,9 +195,18 @@ def check_rpc_hetero_find_edges_shuffle(tmpdir, num_server):
time
.
sleep
(
1
)
pserver_list
.
append
(
p
)
eids
=
F
.
tensor
(
np
.
random
.
randint
(
g
.
number_of_edges
(
'r1'
),
size
=
100
))
u
,
v
=
g
.
find_edges
(
orig_eid
[
'r1'
][
eids
],
etype
=
'r1'
)
du
,
dv
=
start_find_edges_client
(
0
,
tmpdir
,
num_server
>
1
,
eids
,
etype
=
'r1'
)
eids
=
F
.
tensor
(
np
.
random
.
randint
(
g
.
num_edges
(
'r12'
),
size
=
100
))
expect_except
=
False
try
:
_
,
_
=
g
.
find_edges
(
orig_eid
[
'r12'
][
eids
],
etype
=
(
'n1'
,
'r12'
))
except
:
expect_except
=
True
assert
expect_except
u
,
v
=
g
.
find_edges
(
orig_eid
[
'r12'
][
eids
],
etype
=
'r12'
)
u1
,
v1
=
g
.
find_edges
(
orig_eid
[
'r12'
][
eids
],
etype
=
(
'n1'
,
'r12'
,
'n2'
))
assert
F
.
array_equal
(
u
,
u1
)
assert
F
.
array_equal
(
v
,
v1
)
du
,
dv
=
start_find_edges_client
(
0
,
tmpdir
,
num_server
>
1
,
eids
,
etype
=
'r12'
)
du
=
orig_nid
[
'n1'
][
du
]
dv
=
orig_nid
[
'n2'
][
dv
]
assert
F
.
array_equal
(
u
,
du
)
...
...
@@ -488,9 +497,9 @@ def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server, etype_sorted=Fal
for
p
in
pserver_list
:
p
.
join
()
src
,
dst
=
block
.
edges
(
etype
=
(
'n1'
,
'r
2
'
,
'n3'
))
src
,
dst
=
block
.
edges
(
etype
=
(
'n1'
,
'r
13
'
,
'n3'
))
assert
len
(
src
)
==
18
src
,
dst
=
block
.
edges
(
etype
=
(
'n2'
,
'r3'
,
'n3'
))
src
,
dst
=
block
.
edges
(
etype
=
(
'n2'
,
'r
2
3'
,
'n3'
))
assert
len
(
src
)
==
18
orig_nid_map
=
{
ntype
:
F
.
zeros
((
g
.
number_of_nodes
(
ntype
),),
dtype
=
F
.
int64
)
for
ntype
in
g
.
ntypes
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment