Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
4ee0a8bd
Unverified
Commit
4ee0a8bd
authored
Feb 05, 2024
by
Rhett Ying
Committed by
GitHub
Feb 05, 2024
Browse files
[DistGB] return global eids from GB sampling on homograph (#7085)
parent
badeaf19
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
7 deletions
+17
-7
python/dgl/distributed/graph_services.py
python/dgl/distributed/graph_services.py
+6
-3
tests/distributed/test_distributed_sampling.py
tests/distributed/test_distributed_sampling.py
+11
-4
No files found.
python/dgl/distributed/graph_services.py
View file @
4ee0a8bd
...
@@ -145,7 +145,8 @@ def _sample_neighbors_graphbolt(
...
@@ -145,7 +145,8 @@ def _sample_neighbors_graphbolt(
# [Rui][TODO] Support multiple fanouts.
# [Rui][TODO] Support multiple fanouts.
assert
fanout
.
numel
()
==
1
,
"Expect a single fanout."
assert
fanout
.
numel
()
==
1
,
"Expect a single fanout."
subgraph
=
g
.
_sample_neighbors
(
nodes
,
fanout
)
return_eids
=
g
.
edge_attributes
is
not
None
and
EID
in
g
.
edge_attributes
subgraph
=
g
.
_sample_neighbors
(
nodes
,
fanout
,
return_eids
=
return_eids
)
# 3. Map local node IDs to global node IDs.
# 3. Map local node IDs to global node IDs.
local_src
=
subgraph
.
indices
local_src
=
subgraph
.
indices
...
@@ -156,9 +157,11 @@ def _sample_neighbors_graphbolt(
...
@@ -156,9 +157,11 @@ def _sample_neighbors_graphbolt(
global_src
=
global_nid_mapping
[
local_src
]
global_src
=
global_nid_mapping
[
local_src
]
global_dst
=
global_nid_mapping
[
local_dst
]
global_dst
=
global_nid_mapping
[
local_dst
]
# [Rui][TODO] edge IDs are not supported yet.
global_eids
=
None
if
return_eids
:
global_eids
=
g
.
edge_attributes
[
EID
][
subgraph
.
original_edge_ids
]
return
LocalSampledGraph
(
return
LocalSampledGraph
(
global_src
,
global_dst
,
None
,
subgraph
.
type_per_edge
global_src
,
global_dst
,
global_eids
,
subgraph
.
type_per_edge
)
)
...
...
tests/distributed/test_distributed_sampling.py
View file @
4ee0a8bd
...
@@ -75,6 +75,7 @@ def start_sample_client_shuffle(
...
@@ -75,6 +75,7 @@ def start_sample_client_shuffle(
orig_nid
,
orig_nid
,
orig_eid
,
orig_eid
,
use_graphbolt
=
False
,
use_graphbolt
=
False
,
return_eids
=
False
,
):
):
os
.
environ
[
"DGL_GROUP_ID"
]
=
str
(
group_id
)
os
.
environ
[
"DGL_GROUP_ID"
]
=
str
(
group_id
)
gpb
=
None
gpb
=
None
...
@@ -95,7 +96,7 @@ def start_sample_client_shuffle(
...
@@ -95,7 +96,7 @@ def start_sample_client_shuffle(
dst
=
orig_nid
[
dst
]
dst
=
orig_nid
[
dst
]
assert
sampled_graph
.
num_nodes
()
==
g
.
num_nodes
()
assert
sampled_graph
.
num_nodes
()
==
g
.
num_nodes
()
assert
np
.
all
(
F
.
asnumpy
(
g
.
has_edges_between
(
src
,
dst
)))
assert
np
.
all
(
F
.
asnumpy
(
g
.
has_edges_between
(
src
,
dst
)))
if
use_graphbolt
:
if
use_graphbolt
and
not
return_eids
:
assert
(
assert
(
dgl
.
EID
not
in
sampled_graph
.
edata
dgl
.
EID
not
in
sampled_graph
.
edata
),
"EID should not be in sampled graph if use_graphbolt=True."
),
"EID should not be in sampled graph if use_graphbolt=True."
...
@@ -391,7 +392,7 @@ def test_rpc_sampling():
...
@@ -391,7 +392,7 @@ def test_rpc_sampling():
def
check_rpc_sampling_shuffle
(
def
check_rpc_sampling_shuffle
(
tmpdir
,
num_server
,
num_groups
=
1
,
use_graphbolt
=
False
tmpdir
,
num_server
,
num_groups
=
1
,
use_graphbolt
=
False
,
return_eids
=
False
):
):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
...
@@ -408,6 +409,7 @@ def check_rpc_sampling_shuffle(
...
@@ -408,6 +409,7 @@ def check_rpc_sampling_shuffle(
part_method
=
"metis"
,
part_method
=
"metis"
,
return_mapping
=
True
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
)
pserver_list
=
[]
pserver_list
=
[]
...
@@ -444,6 +446,7 @@ def check_rpc_sampling_shuffle(
...
@@ -444,6 +446,7 @@ def check_rpc_sampling_shuffle(
orig_nids
,
orig_nids
,
orig_eids
,
orig_eids
,
use_graphbolt
,
use_graphbolt
,
return_eids
,
),
),
)
)
p
.
start
()
p
.
start
()
...
@@ -1015,12 +1018,16 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
...
@@ -1015,12 +1018,16 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
def
test_rpc_sampling_shuffle
(
num_server
,
use_graphbolt
):
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_sampling_shuffle
(
num_server
,
use_graphbolt
,
return_eids
):
reset_envs
()
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_sampling_shuffle
(
check_rpc_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
use_graphbolt
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment