Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
8ab27b05
Unverified
Commit
8ab27b05
authored
Mar 08, 2024
by
Rhett Ying
Committed by
GitHub
Mar 08, 2024
Browse files
[DistGB] enable replacement sampling with GraphBolt API (#7202)
parent
b0982feb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
9 deletions
+29
-9
python/dgl/distributed/graph_services.py
python/dgl/distributed/graph_services.py
+4
-5
tests/distributed/test_distributed_sampling.py
tests/distributed/test_distributed_sampling.py
+25
-4
No files found.
python/dgl/distributed/graph_services.py
View file @
8ab27b05
...
@@ -130,13 +130,10 @@ def _sample_neighbors_graphbolt(
...
@@ -130,13 +130,10 @@ def _sample_neighbors_graphbolt(
nodes
=
nodes
.
to
(
dtype
=
g
.
indices
.
dtype
)
nodes
=
nodes
.
to
(
dtype
=
g
.
indices
.
dtype
)
# 2. Perform sampling.
# 2. Perform sampling.
# [Rui][TODO] `prob`
and `replace` are
not tested yet. Skip for now.
# [Rui][TODO] `prob`
is
not tested yet. Skip for now.
assert
(
assert
(
prob
is
None
prob
is
None
),
"DistGraphBolt does not support sampling with probability."
),
"DistGraphBolt does not support sampling with probability."
assert
(
not
replace
),
"DistGraphBolt does not support sampling with replacement."
# Sanity checks.
# Sanity checks.
assert
isinstance
(
assert
isinstance
(
...
@@ -148,7 +145,9 @@ def _sample_neighbors_graphbolt(
...
@@ -148,7 +145,9 @@ def _sample_neighbors_graphbolt(
assert
isinstance
(
fanout
,
torch
.
Tensor
),
"Expect a tensor of fanout."
assert
isinstance
(
fanout
,
torch
.
Tensor
),
"Expect a tensor of fanout."
return_eids
=
g
.
edge_attributes
is
not
None
and
EID
in
g
.
edge_attributes
return_eids
=
g
.
edge_attributes
is
not
None
and
EID
in
g
.
edge_attributes
subgraph
=
g
.
_sample_neighbors
(
nodes
,
fanout
,
return_eids
=
return_eids
)
subgraph
=
g
.
_sample_neighbors
(
nodes
,
fanout
,
replace
=
replace
,
return_eids
=
return_eids
)
# 3. Map local node IDs to global node IDs.
# 3. Map local node IDs to global node IDs.
local_src
=
subgraph
.
indices
local_src
=
subgraph
.
indices
...
...
tests/distributed/test_distributed_sampling.py
View file @
8ab27b05
...
@@ -80,6 +80,7 @@ def start_sample_client_shuffle(
...
@@ -80,6 +80,7 @@ def start_sample_client_shuffle(
use_graphbolt
=
False
,
use_graphbolt
=
False
,
return_eids
=
False
,
return_eids
=
False
,
node_id_dtype
=
None
,
node_id_dtype
=
None
,
replace
=
False
,
):
):
os
.
environ
[
"DGL_GROUP_ID"
]
=
str
(
group_id
)
os
.
environ
[
"DGL_GROUP_ID"
]
=
str
(
group_id
)
gpb
=
None
gpb
=
None
...
@@ -93,6 +94,7 @@ def start_sample_client_shuffle(
...
@@ -93,6 +94,7 @@ def start_sample_client_shuffle(
dist_graph
,
dist_graph
,
torch
.
tensor
([
0
,
10
,
99
,
66
,
1024
,
2008
],
dtype
=
node_id_dtype
),
torch
.
tensor
([
0
,
10
,
99
,
66
,
1024
,
2008
],
dtype
=
node_id_dtype
),
3
,
3
,
replace
=
replace
,
use_graphbolt
=
use_graphbolt
,
use_graphbolt
=
use_graphbolt
,
)
)
assert
sampled_graph
.
idtype
==
dist_graph
.
idtype
assert
sampled_graph
.
idtype
==
dist_graph
.
idtype
...
@@ -102,6 +104,7 @@ def start_sample_client_shuffle(
...
@@ -102,6 +104,7 @@ def start_sample_client_shuffle(
dgl
.
ETYPE
not
in
sampled_graph
.
edata
dgl
.
ETYPE
not
in
sampled_graph
.
edata
),
"Etype should not be in homogeneous sampled graph."
),
"Etype should not be in homogeneous sampled graph."
src
,
dst
=
sampled_graph
.
edges
()
src
,
dst
=
sampled_graph
.
edges
()
sampled_in_degrees
=
sampled_graph
.
in_degrees
(
dst
)
src
=
orig_nid
[
src
]
src
=
orig_nid
[
src
]
dst
=
orig_nid
[
dst
]
dst
=
orig_nid
[
dst
]
assert
sampled_graph
.
num_nodes
()
==
g
.
num_nodes
()
assert
sampled_graph
.
num_nodes
()
==
g
.
num_nodes
()
...
@@ -114,6 +117,14 @@ def start_sample_client_shuffle(
...
@@ -114,6 +117,14 @@ def start_sample_client_shuffle(
eids
=
g
.
edge_ids
(
src
,
dst
)
eids
=
g
.
edge_ids
(
src
,
dst
)
eids1
=
orig_eid
[
sampled_graph
.
edata
[
dgl
.
EID
]]
eids1
=
orig_eid
[
sampled_graph
.
edata
[
dgl
.
EID
]]
assert
np
.
array_equal
(
F
.
asnumpy
(
eids1
),
F
.
asnumpy
(
eids
))
assert
np
.
array_equal
(
F
.
asnumpy
(
eids1
),
F
.
asnumpy
(
eids
))
# Verify replace argument.
orig_in_degrees
=
g
.
in_degrees
(
dst
)
if
replace
:
assert
torch
.
all
(
(
sampled_in_degrees
==
3
)
|
(
sampled_in_degrees
==
orig_in_degrees
)
)
else
:
assert
torch
.
all
(
sampled_in_degrees
<=
3
)
def
start_find_edges_client
(
rank
,
tmpdir
,
disable_shared_mem
,
eids
,
etype
=
None
):
def
start_find_edges_client
(
rank
,
tmpdir
,
disable_shared_mem
,
eids
,
etype
=
None
):
...
@@ -408,6 +419,7 @@ def check_rpc_sampling_shuffle(
...
@@ -408,6 +419,7 @@ def check_rpc_sampling_shuffle(
use_graphbolt
=
False
,
use_graphbolt
=
False
,
return_eids
=
False
,
return_eids
=
False
,
node_id_dtype
=
None
,
node_id_dtype
=
None
,
replace
=
False
,
):
):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
...
@@ -463,6 +475,7 @@ def check_rpc_sampling_shuffle(
...
@@ -463,6 +475,7 @@ def check_rpc_sampling_shuffle(
use_graphbolt
,
use_graphbolt
,
return_eids
,
return_eids
,
node_id_dtype
,
node_id_dtype
,
replace
,
),
),
)
)
p
.
start
()
p
.
start
()
...
@@ -483,6 +496,7 @@ def start_hetero_sample_client(
...
@@ -483,6 +496,7 @@ def start_hetero_sample_client(
nodes
,
nodes
,
use_graphbolt
=
False
,
use_graphbolt
=
False
,
return_eids
=
False
,
return_eids
=
False
,
replace
=
False
,
):
):
gpb
=
None
gpb
=
None
if
disable_shared_mem
:
if
disable_shared_mem
:
...
@@ -503,7 +517,7 @@ def start_hetero_sample_client(
...
@@ -503,7 +517,7 @@ def start_hetero_sample_client(
# Enable santity check in distributed sampling.
# Enable santity check in distributed sampling.
os
.
environ
[
"DGL_DIST_DEBUG"
]
=
"1"
os
.
environ
[
"DGL_DIST_DEBUG"
]
=
"1"
sampled_graph
=
sample_neighbors
(
sampled_graph
=
sample_neighbors
(
dist_graph
,
nodes
,
3
,
use_graphbolt
=
use_graphbolt
dist_graph
,
nodes
,
3
,
replace
=
replace
,
use_graphbolt
=
use_graphbolt
)
)
block
=
dgl
.
to_block
(
sampled_graph
,
nodes
)
block
=
dgl
.
to_block
(
sampled_graph
,
nodes
)
if
not
use_graphbolt
or
return_eids
:
if
not
use_graphbolt
or
return_eids
:
...
@@ -573,7 +587,7 @@ def start_hetero_etype_sample_client(
...
@@ -573,7 +587,7 @@ def start_hetero_etype_sample_client(
def
check_rpc_hetero_sampling_shuffle
(
def
check_rpc_hetero_sampling_shuffle
(
tmpdir
,
num_server
,
use_graphbolt
=
False
,
return_eids
=
False
tmpdir
,
num_server
,
use_graphbolt
=
False
,
return_eids
=
False
,
replace
=
False
):
):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
...
@@ -619,6 +633,7 @@ def check_rpc_hetero_sampling_shuffle(
...
@@ -619,6 +633,7 @@ def check_rpc_hetero_sampling_shuffle(
nodes
=
nodes
,
nodes
=
nodes
,
use_graphbolt
=
use_graphbolt
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
return_eids
=
return_eids
,
replace
=
replace
,
)
)
for
p
in
pserver_list
:
for
p
in
pserver_list
:
p
.
join
()
p
.
join
()
...
@@ -1247,8 +1262,9 @@ def check_rpc_bipartite_etype_sampling_shuffle(
...
@@ -1247,8 +1262,9 @@ def check_rpc_bipartite_etype_sampling_shuffle(
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"node_id_dtype"
,
[
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"node_id_dtype"
,
[
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"replace"
,
[
False
,
True
])
def
test_rpc_sampling_shuffle
(
def
test_rpc_sampling_shuffle
(
num_server
,
use_graphbolt
,
return_eids
,
node_id_dtype
num_server
,
use_graphbolt
,
return_eids
,
node_id_dtype
,
replace
):
):
reset_envs
()
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
...
@@ -1259,13 +1275,17 @@ def test_rpc_sampling_shuffle(
...
@@ -1259,13 +1275,17 @@ def test_rpc_sampling_shuffle(
use_graphbolt
=
use_graphbolt
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
return_eids
=
return_eids
,
node_id_dtype
=
node_id_dtype
,
node_id_dtype
=
node_id_dtype
,
replace
=
replace
,
)
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"use_graphbolt,"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"use_graphbolt,"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_hetero_sampling_shuffle
(
num_server
,
use_graphbolt
,
return_eids
):
@
pytest
.
mark
.
parametrize
(
"replace"
,
[
False
,
True
])
def
test_rpc_hetero_sampling_shuffle
(
num_server
,
use_graphbolt
,
return_eids
,
replace
):
reset_envs
()
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
...
@@ -1274,6 +1294,7 @@ def test_rpc_hetero_sampling_shuffle(num_server, use_graphbolt, return_eids):
...
@@ -1274,6 +1294,7 @@ def test_rpc_hetero_sampling_shuffle(num_server, use_graphbolt, return_eids):
num_server
,
num_server
,
use_graphbolt
=
use_graphbolt
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
return_eids
=
return_eids
,
replace
=
replace
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment