Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
6459a688
Unverified
Commit
6459a688
authored
Feb 04, 2024
by
Rhett Ying
Committed by
GitHub
Feb 04, 2024
Browse files
[DistGB] enable GB sampling on homograph (#7061)
parent
9273387e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
157 additions
and
13 deletions
+157
-13
python/dgl/distributed/graph_services.py
python/dgl/distributed/graph_services.py
+131
-7
tests/distributed/test_distributed_sampling.py
tests/distributed/test_distributed_sampling.py
+26
-6
No files found.
python/dgl/distributed/graph_services.py
View file @
6459a688
...
...
@@ -3,7 +3,9 @@ from collections import namedtuple
import
numpy
as
np
from
..
import
backend
as
F
import
torch
from
..
import
backend
as
F
,
graphbolt
as
gb
from
..base
import
EID
,
NID
from
..convert
import
graph
,
heterograph
from
..sampling
import
(
...
...
@@ -65,6 +67,81 @@ class FindEdgeResponse(Response):
return
self
.
global_src
,
self
.
global_dst
,
self
.
order_id
def
_sample_neighbors_graphbolt
(
g
,
gpb
,
nodes
,
fanout
,
prob
=
None
,
replace
=
False
):
"""Sample from local partition via graphbolt.
The input nodes use global IDs. We need to map the global node IDs to local
node IDs, perform sampling and map the sampled results to the global IDs
space again. The sampled results are stored in three vectors that store
source nodes, destination nodes, etype IDs and edge IDs.
[Rui][TODO] edge IDs are not returned as not supported yet.
Parameters
----------
g : FusedCSCSamplingGraph
The local partition.
gpb : GraphPartitionBook
The graph partition book.
nodes : tensor
The nodes to sample neighbors from.
fanout : tensor or int
The number of edges to be sampled for each node.
prob : tensor, optional
The probability associated with each neighboring edge of a node.
replace : bool, optional
If True, sample with replacement.
Returns
-------
tensor
The source node ID array.
tensor
The destination node ID array.
tensor
The edge type ID array.
tensor
The edge ID array.
"""
# 1. Map global node IDs to local node IDs.
nodes
=
gpb
.
nid2localnid
(
nodes
,
gpb
.
partid
)
# 2. Perform sampling.
# [Rui][TODO] `prob` and `replace` are not tested yet. Skip for now.
assert
(
prob
is
None
),
"DistGraphBolt does not support sampling with probability."
assert
(
not
replace
),
"DistGraphBolt does not support sampling with replacement."
# Sanity checks.
assert
isinstance
(
g
,
gb
.
FusedCSCSamplingGraph
),
"Expect a FusedCSCSamplingGraph."
assert
isinstance
(
nodes
,
torch
.
Tensor
),
"Expect a tensor of nodes."
if
isinstance
(
fanout
,
int
):
fanout
=
torch
.
LongTensor
([
fanout
])
assert
isinstance
(
fanout
,
torch
.
Tensor
),
"Expect a tensor of fanout."
# [Rui][TODO] Support multiple fanouts.
assert
fanout
.
numel
()
==
1
,
"Expect a single fanout."
subgraph
=
g
.
_sample_neighbors
(
nodes
,
fanout
)
# 3. Map local node IDs to global node IDs.
local_src
=
subgraph
.
indices
local_dst
=
torch
.
repeat_interleave
(
subgraph
.
original_column_node_ids
,
torch
.
diff
(
subgraph
.
indptr
)
)
global_nid_mapping
=
g
.
node_attributes
[
NID
]
global_src
=
global_nid_mapping
[
local_src
]
global_dst
=
global_nid_mapping
[
local_dst
]
return
global_src
,
global_dst
,
subgraph
.
type_per_edge
def
_sample_neighbors
(
local_g
,
partition_book
,
seed_nodes
,
fan_out
,
edge_dir
,
prob
,
replace
):
...
...
@@ -212,12 +289,21 @@ def _in_subgraph(local_g, partition_book, seed_nodes):
class
SamplingRequest
(
Request
):
"""Sampling Request"""
def
__init__
(
self
,
nodes
,
fan_out
,
edge_dir
=
"in"
,
prob
=
None
,
replace
=
False
):
def
__init__
(
self
,
nodes
,
fan_out
,
edge_dir
=
"in"
,
prob
=
None
,
replace
=
False
,
use_graphbolt
=
False
,
):
self
.
seed_nodes
=
nodes
self
.
edge_dir
=
edge_dir
self
.
prob
=
prob
self
.
replace
=
replace
self
.
fan_out
=
fan_out
self
.
use_graphbolt
=
use_graphbolt
def
__setstate__
(
self
,
state
):
(
...
...
@@ -226,6 +312,7 @@ class SamplingRequest(Request):
self
.
prob
,
self
.
replace
,
self
.
fan_out
,
self
.
use_graphbolt
,
)
=
state
def
__getstate__
(
self
):
...
...
@@ -235,6 +322,7 @@ class SamplingRequest(Request):
self
.
prob
,
self
.
replace
,
self
.
fan_out
,
self
.
use_graphbolt
,
)
def
process_request
(
self
,
server_state
):
...
...
@@ -245,6 +333,16 @@ class SamplingRequest(Request):
prob
=
[
kv_store
.
data_store
[
self
.
prob
]]
else
:
prob
=
None
if
self
.
use_graphbolt
:
global_src
,
global_dst
,
etype_ids
=
_sample_neighbors_graphbolt
(
local_g
,
partition_book
,
self
.
seed_nodes
,
self
.
fan_out
,
prob
,
self
.
replace
,
)
return
SubgraphResponse
(
global_src
,
global_dst
,
etype_ids
)
global_src
,
global_dst
,
global_eids
=
_sample_neighbors
(
local_g
,
partition_book
,
...
...
@@ -449,12 +547,13 @@ def merge_graphs(res_list, num_nodes):
eids
.
append
(
res
.
global_eids
)
src_tensor
=
F
.
cat
(
srcs
,
0
)
dst_tensor
=
F
.
cat
(
dsts
,
0
)
eid_tensor
=
F
.
cat
(
eids
,
0
)
eid_tensor
=
None
if
eids
[
0
]
is
None
else
F
.
cat
(
eids
,
0
)
else
:
src_tensor
=
res_list
[
0
].
global_src
dst_tensor
=
res_list
[
0
].
global_dst
eid_tensor
=
res_list
[
0
].
global_eids
g
=
graph
((
src_tensor
,
dst_tensor
),
num_nodes
=
num_nodes
)
if
eid_tensor
is
not
None
:
g
.
edata
[
EID
]
=
eid_tensor
return
g
...
...
@@ -491,6 +590,7 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
"""
req_list
=
[]
partition_book
=
g
.
get_partition_book
()
if
not
isinstance
(
nodes
,
torch
.
Tensor
):
nodes
=
toindex
(
nodes
).
tousertensor
()
partition_id
=
partition_book
.
nid2partid
(
nodes
)
local_nids
=
None
...
...
@@ -721,7 +821,15 @@ def sample_etype_neighbors(
return
frontier
def
sample_neighbors
(
g
,
nodes
,
fanout
,
edge_dir
=
"in"
,
prob
=
None
,
replace
=
False
):
def
sample_neighbors
(
g
,
nodes
,
fanout
,
edge_dir
=
"in"
,
prob
=
None
,
replace
=
False
,
use_graphbolt
=
False
,
):
"""Sample from the neighbors of the given nodes from a distributed graph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
...
...
@@ -764,6 +872,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False):
For sampling without replacement, if fanout > the number of neighbors, all the
neighbors are sampled. If fanout == -1, all neighbors are collected.
use_graphbolt : bool, optional
Whether to use GraphBolt for sampling.
Returns
-------
...
...
@@ -795,12 +905,26 @@ def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False):
else
:
_prob
=
None
return
SamplingRequest
(
node_ids
,
fanout
,
edge_dir
=
edge_dir
,
prob
=
_prob
,
replace
=
replace
node_ids
,
fanout
,
edge_dir
=
edge_dir
,
prob
=
_prob
,
replace
=
replace
,
use_graphbolt
=
use_graphbolt
,
)
def
local_access
(
local_g
,
partition_book
,
local_nids
):
# See NOTE 1
_prob
=
[
g
.
edata
[
prob
].
local_partition
]
if
prob
is
not
None
else
None
if
use_graphbolt
:
return
_sample_neighbors_graphbolt
(
local_g
,
partition_book
,
local_nids
,
fanout
,
prob
=
_prob
,
replace
=
replace
,
)
return
_sample_neighbors
(
local_g
,
partition_book
,
...
...
tests/distributed/test_distributed_sampling.py
View file @
6459a688
...
...
@@ -31,6 +31,7 @@ def start_server(
disable_shared_mem
,
graph_name
,
graph_format
=
[
"csc"
,
"coo"
],
use_graphbolt
=
False
,
):
g
=
DistGraphServer
(
rank
,
...
...
@@ -40,6 +41,7 @@ def start_server(
tmpdir
/
(
graph_name
+
".json"
),
disable_shared_mem
=
disable_shared_mem
,
graph_format
=
graph_format
,
use_graphbolt
=
use_graphbolt
,
)
g
.
start
()
...
...
@@ -72,6 +74,7 @@ def start_sample_client_shuffle(
group_id
,
orig_nid
,
orig_eid
,
use_graphbolt
=
False
,
):
os
.
environ
[
"DGL_GROUP_ID"
]
=
str
(
group_id
)
gpb
=
None
...
...
@@ -80,14 +83,23 @@ def start_sample_client_shuffle(
tmpdir
/
"test_sampling.json"
,
rank
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
sampled_graph
=
sample_neighbors
(
dist_graph
,
[
0
,
10
,
99
,
66
,
1024
,
2008
],
3
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
sampled_graph
=
sample_neighbors
(
dist_graph
,
[
0
,
10
,
99
,
66
,
1024
,
2008
],
3
,
use_graphbolt
=
use_graphbolt
)
src
,
dst
=
sampled_graph
.
edges
()
src
=
orig_nid
[
src
]
dst
=
orig_nid
[
dst
]
assert
sampled_graph
.
num_nodes
()
==
g
.
num_nodes
()
assert
np
.
all
(
F
.
asnumpy
(
g
.
has_edges_between
(
src
,
dst
)))
if
use_graphbolt
:
assert
(
dgl
.
EID
not
in
sampled_graph
.
edata
),
"EID should not be in sampled graph if use_graphbolt=True."
else
:
eids
=
g
.
edge_ids
(
src
,
dst
)
eids1
=
orig_eid
[
sampled_graph
.
edata
[
dgl
.
EID
]]
assert
np
.
array_equal
(
F
.
asnumpy
(
eids1
),
F
.
asnumpy
(
eids
))
...
...
@@ -378,7 +390,9 @@ def test_rpc_sampling():
check_rpc_sampling
(
Path
(
tmpdirname
),
1
)
def
check_rpc_sampling_shuffle
(
tmpdir
,
num_server
,
num_groups
=
1
):
def
check_rpc_sampling_shuffle
(
tmpdir
,
num_server
,
num_groups
=
1
,
use_graphbolt
=
False
):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
g
=
CitationGraphDataset
(
"cora"
)[
0
]
...
...
@@ -393,6 +407,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
num_hops
=
num_hops
,
part_method
=
"metis"
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
)
pserver_list
=
[]
...
...
@@ -406,6 +421,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
],
use_graphbolt
,
),
)
p
.
start
()
...
...
@@ -427,6 +443,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
group_id
,
orig_nids
,
orig_eids
,
use_graphbolt
,
),
)
p
.
start
()
...
...
@@ -1012,6 +1029,9 @@ def test_rpc_sampling_shuffle(num_server):
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
True
)
check_rpc_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
# [TODO][Rhett] Tests for multiple groups may fail sometimes and
# root cause is unknown. Let's disable them for now.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment