Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
6459a688
"src/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "51651ecadc70cb4b254881e1211a92dea9174cdb"
Unverified
Commit
6459a688
authored
Feb 04, 2024
by
Rhett Ying
Committed by
GitHub
Feb 04, 2024
Browse files
[DistGB] enable GB sampling on homograph (#7061)
parent
9273387e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
157 additions
and
13 deletions
+157
-13
python/dgl/distributed/graph_services.py
python/dgl/distributed/graph_services.py
+131
-7
tests/distributed/test_distributed_sampling.py
tests/distributed/test_distributed_sampling.py
+26
-6
No files found.
python/dgl/distributed/graph_services.py
View file @
6459a688
...
@@ -3,7 +3,9 @@ from collections import namedtuple
...
@@ -3,7 +3,9 @@ from collections import namedtuple
import
numpy
as
np
import
numpy
as
np
from
..
import
backend
as
F
import
torch
from
..
import
backend
as
F
,
graphbolt
as
gb
from
..base
import
EID
,
NID
from
..base
import
EID
,
NID
from
..convert
import
graph
,
heterograph
from
..convert
import
graph
,
heterograph
from
..sampling
import
(
from
..sampling
import
(
...
@@ -65,6 +67,81 @@ class FindEdgeResponse(Response):
...
@@ -65,6 +67,81 @@ class FindEdgeResponse(Response):
return
self
.
global_src
,
self
.
global_dst
,
self
.
order_id
return
self
.
global_src
,
self
.
global_dst
,
self
.
order_id
def
_sample_neighbors_graphbolt
(
g
,
gpb
,
nodes
,
fanout
,
prob
=
None
,
replace
=
False
):
"""Sample from local partition via graphbolt.
The input nodes use global IDs. We need to map the global node IDs to local
node IDs, perform sampling and map the sampled results to the global IDs
space again. The sampled results are stored in three vectors that store
source nodes, destination nodes, etype IDs and edge IDs.
[Rui][TODO] edge IDs are not returned as not supported yet.
Parameters
----------
g : FusedCSCSamplingGraph
The local partition.
gpb : GraphPartitionBook
The graph partition book.
nodes : tensor
The nodes to sample neighbors from.
fanout : tensor or int
The number of edges to be sampled for each node.
prob : tensor, optional
The probability associated with each neighboring edge of a node.
replace : bool, optional
If True, sample with replacement.
Returns
-------
tensor
The source node ID array.
tensor
The destination node ID array.
tensor
The edge type ID array.
tensor
The edge ID array.
"""
# 1. Map global node IDs to local node IDs.
nodes
=
gpb
.
nid2localnid
(
nodes
,
gpb
.
partid
)
# 2. Perform sampling.
# [Rui][TODO] `prob` and `replace` are not tested yet. Skip for now.
assert
(
prob
is
None
),
"DistGraphBolt does not support sampling with probability."
assert
(
not
replace
),
"DistGraphBolt does not support sampling with replacement."
# Sanity checks.
assert
isinstance
(
g
,
gb
.
FusedCSCSamplingGraph
),
"Expect a FusedCSCSamplingGraph."
assert
isinstance
(
nodes
,
torch
.
Tensor
),
"Expect a tensor of nodes."
if
isinstance
(
fanout
,
int
):
fanout
=
torch
.
LongTensor
([
fanout
])
assert
isinstance
(
fanout
,
torch
.
Tensor
),
"Expect a tensor of fanout."
# [Rui][TODO] Support multiple fanouts.
assert
fanout
.
numel
()
==
1
,
"Expect a single fanout."
subgraph
=
g
.
_sample_neighbors
(
nodes
,
fanout
)
# 3. Map local node IDs to global node IDs.
local_src
=
subgraph
.
indices
local_dst
=
torch
.
repeat_interleave
(
subgraph
.
original_column_node_ids
,
torch
.
diff
(
subgraph
.
indptr
)
)
global_nid_mapping
=
g
.
node_attributes
[
NID
]
global_src
=
global_nid_mapping
[
local_src
]
global_dst
=
global_nid_mapping
[
local_dst
]
return
global_src
,
global_dst
,
subgraph
.
type_per_edge
def
_sample_neighbors
(
def
_sample_neighbors
(
local_g
,
partition_book
,
seed_nodes
,
fan_out
,
edge_dir
,
prob
,
replace
local_g
,
partition_book
,
seed_nodes
,
fan_out
,
edge_dir
,
prob
,
replace
):
):
...
@@ -212,12 +289,21 @@ def _in_subgraph(local_g, partition_book, seed_nodes):
...
@@ -212,12 +289,21 @@ def _in_subgraph(local_g, partition_book, seed_nodes):
class
SamplingRequest
(
Request
):
class
SamplingRequest
(
Request
):
"""Sampling Request"""
"""Sampling Request"""
def
__init__
(
self
,
nodes
,
fan_out
,
edge_dir
=
"in"
,
prob
=
None
,
replace
=
False
):
def
__init__
(
self
,
nodes
,
fan_out
,
edge_dir
=
"in"
,
prob
=
None
,
replace
=
False
,
use_graphbolt
=
False
,
):
self
.
seed_nodes
=
nodes
self
.
seed_nodes
=
nodes
self
.
edge_dir
=
edge_dir
self
.
edge_dir
=
edge_dir
self
.
prob
=
prob
self
.
prob
=
prob
self
.
replace
=
replace
self
.
replace
=
replace
self
.
fan_out
=
fan_out
self
.
fan_out
=
fan_out
self
.
use_graphbolt
=
use_graphbolt
def
__setstate__
(
self
,
state
):
def
__setstate__
(
self
,
state
):
(
(
...
@@ -226,6 +312,7 @@ class SamplingRequest(Request):
...
@@ -226,6 +312,7 @@ class SamplingRequest(Request):
self
.
prob
,
self
.
prob
,
self
.
replace
,
self
.
replace
,
self
.
fan_out
,
self
.
fan_out
,
self
.
use_graphbolt
,
)
=
state
)
=
state
def
__getstate__
(
self
):
def
__getstate__
(
self
):
...
@@ -235,6 +322,7 @@ class SamplingRequest(Request):
...
@@ -235,6 +322,7 @@ class SamplingRequest(Request):
self
.
prob
,
self
.
prob
,
self
.
replace
,
self
.
replace
,
self
.
fan_out
,
self
.
fan_out
,
self
.
use_graphbolt
,
)
)
def
process_request
(
self
,
server_state
):
def
process_request
(
self
,
server_state
):
...
@@ -245,6 +333,16 @@ class SamplingRequest(Request):
...
@@ -245,6 +333,16 @@ class SamplingRequest(Request):
prob
=
[
kv_store
.
data_store
[
self
.
prob
]]
prob
=
[
kv_store
.
data_store
[
self
.
prob
]]
else
:
else
:
prob
=
None
prob
=
None
if
self
.
use_graphbolt
:
global_src
,
global_dst
,
etype_ids
=
_sample_neighbors_graphbolt
(
local_g
,
partition_book
,
self
.
seed_nodes
,
self
.
fan_out
,
prob
,
self
.
replace
,
)
return
SubgraphResponse
(
global_src
,
global_dst
,
etype_ids
)
global_src
,
global_dst
,
global_eids
=
_sample_neighbors
(
global_src
,
global_dst
,
global_eids
=
_sample_neighbors
(
local_g
,
local_g
,
partition_book
,
partition_book
,
...
@@ -449,13 +547,14 @@ def merge_graphs(res_list, num_nodes):
...
@@ -449,13 +547,14 @@ def merge_graphs(res_list, num_nodes):
eids
.
append
(
res
.
global_eids
)
eids
.
append
(
res
.
global_eids
)
src_tensor
=
F
.
cat
(
srcs
,
0
)
src_tensor
=
F
.
cat
(
srcs
,
0
)
dst_tensor
=
F
.
cat
(
dsts
,
0
)
dst_tensor
=
F
.
cat
(
dsts
,
0
)
eid_tensor
=
F
.
cat
(
eids
,
0
)
eid_tensor
=
None
if
eids
[
0
]
is
None
else
F
.
cat
(
eids
,
0
)
else
:
else
:
src_tensor
=
res_list
[
0
].
global_src
src_tensor
=
res_list
[
0
].
global_src
dst_tensor
=
res_list
[
0
].
global_dst
dst_tensor
=
res_list
[
0
].
global_dst
eid_tensor
=
res_list
[
0
].
global_eids
eid_tensor
=
res_list
[
0
].
global_eids
g
=
graph
((
src_tensor
,
dst_tensor
),
num_nodes
=
num_nodes
)
g
=
graph
((
src_tensor
,
dst_tensor
),
num_nodes
=
num_nodes
)
g
.
edata
[
EID
]
=
eid_tensor
if
eid_tensor
is
not
None
:
g
.
edata
[
EID
]
=
eid_tensor
return
g
return
g
...
@@ -491,7 +590,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
...
@@ -491,7 +590,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
"""
"""
req_list
=
[]
req_list
=
[]
partition_book
=
g
.
get_partition_book
()
partition_book
=
g
.
get_partition_book
()
nodes
=
toindex
(
nodes
).
tousertensor
()
if
not
isinstance
(
nodes
,
torch
.
Tensor
):
nodes
=
toindex
(
nodes
).
tousertensor
()
partition_id
=
partition_book
.
nid2partid
(
nodes
)
partition_id
=
partition_book
.
nid2partid
(
nodes
)
local_nids
=
None
local_nids
=
None
for
pid
in
range
(
partition_book
.
num_partitions
()):
for
pid
in
range
(
partition_book
.
num_partitions
()):
...
@@ -721,7 +821,15 @@ def sample_etype_neighbors(
...
@@ -721,7 +821,15 @@ def sample_etype_neighbors(
return
frontier
return
frontier
def
sample_neighbors
(
g
,
nodes
,
fanout
,
edge_dir
=
"in"
,
prob
=
None
,
replace
=
False
):
def
sample_neighbors
(
g
,
nodes
,
fanout
,
edge_dir
=
"in"
,
prob
=
None
,
replace
=
False
,
use_graphbolt
=
False
,
):
"""Sample from the neighbors of the given nodes from a distributed graph.
"""Sample from the neighbors of the given nodes from a distributed graph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
...
@@ -764,6 +872,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False):
...
@@ -764,6 +872,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False):
For sampling without replacement, if fanout > the number of neighbors, all the
For sampling without replacement, if fanout > the number of neighbors, all the
neighbors are sampled. If fanout == -1, all neighbors are collected.
neighbors are sampled. If fanout == -1, all neighbors are collected.
use_graphbolt : bool, optional
Whether to use GraphBolt for sampling.
Returns
Returns
-------
-------
...
@@ -795,12 +905,26 @@ def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False):
...
@@ -795,12 +905,26 @@ def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False):
else
:
else
:
_prob
=
None
_prob
=
None
return
SamplingRequest
(
return
SamplingRequest
(
node_ids
,
fanout
,
edge_dir
=
edge_dir
,
prob
=
_prob
,
replace
=
replace
node_ids
,
fanout
,
edge_dir
=
edge_dir
,
prob
=
_prob
,
replace
=
replace
,
use_graphbolt
=
use_graphbolt
,
)
)
def
local_access
(
local_g
,
partition_book
,
local_nids
):
def
local_access
(
local_g
,
partition_book
,
local_nids
):
# See NOTE 1
# See NOTE 1
_prob
=
[
g
.
edata
[
prob
].
local_partition
]
if
prob
is
not
None
else
None
_prob
=
[
g
.
edata
[
prob
].
local_partition
]
if
prob
is
not
None
else
None
if
use_graphbolt
:
return
_sample_neighbors_graphbolt
(
local_g
,
partition_book
,
local_nids
,
fanout
,
prob
=
_prob
,
replace
=
replace
,
)
return
_sample_neighbors
(
return
_sample_neighbors
(
local_g
,
local_g
,
partition_book
,
partition_book
,
...
...
tests/distributed/test_distributed_sampling.py
View file @
6459a688
...
@@ -31,6 +31,7 @@ def start_server(
...
@@ -31,6 +31,7 @@ def start_server(
disable_shared_mem
,
disable_shared_mem
,
graph_name
,
graph_name
,
graph_format
=
[
"csc"
,
"coo"
],
graph_format
=
[
"csc"
,
"coo"
],
use_graphbolt
=
False
,
):
):
g
=
DistGraphServer
(
g
=
DistGraphServer
(
rank
,
rank
,
...
@@ -40,6 +41,7 @@ def start_server(
...
@@ -40,6 +41,7 @@ def start_server(
tmpdir
/
(
graph_name
+
".json"
),
tmpdir
/
(
graph_name
+
".json"
),
disable_shared_mem
=
disable_shared_mem
,
disable_shared_mem
=
disable_shared_mem
,
graph_format
=
graph_format
,
graph_format
=
graph_format
,
use_graphbolt
=
use_graphbolt
,
)
)
g
.
start
()
g
.
start
()
...
@@ -72,6 +74,7 @@ def start_sample_client_shuffle(
...
@@ -72,6 +74,7 @@ def start_sample_client_shuffle(
group_id
,
group_id
,
orig_nid
,
orig_nid
,
orig_eid
,
orig_eid
,
use_graphbolt
=
False
,
):
):
os
.
environ
[
"DGL_GROUP_ID"
]
=
str
(
group_id
)
os
.
environ
[
"DGL_GROUP_ID"
]
=
str
(
group_id
)
gpb
=
None
gpb
=
None
...
@@ -80,17 +83,26 @@ def start_sample_client_shuffle(
...
@@ -80,17 +83,26 @@ def start_sample_client_shuffle(
tmpdir
/
"test_sampling.json"
,
rank
tmpdir
/
"test_sampling.json"
,
rank
)
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
dist_graph
=
DistGraph
(
sampled_graph
=
sample_neighbors
(
dist_graph
,
[
0
,
10
,
99
,
66
,
1024
,
2008
],
3
)
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
sampled_graph
=
sample_neighbors
(
dist_graph
,
[
0
,
10
,
99
,
66
,
1024
,
2008
],
3
,
use_graphbolt
=
use_graphbolt
)
src
,
dst
=
sampled_graph
.
edges
()
src
,
dst
=
sampled_graph
.
edges
()
src
=
orig_nid
[
src
]
src
=
orig_nid
[
src
]
dst
=
orig_nid
[
dst
]
dst
=
orig_nid
[
dst
]
assert
sampled_graph
.
num_nodes
()
==
g
.
num_nodes
()
assert
sampled_graph
.
num_nodes
()
==
g
.
num_nodes
()
assert
np
.
all
(
F
.
asnumpy
(
g
.
has_edges_between
(
src
,
dst
)))
assert
np
.
all
(
F
.
asnumpy
(
g
.
has_edges_between
(
src
,
dst
)))
eids
=
g
.
edge_ids
(
src
,
dst
)
if
use_graphbolt
:
eids1
=
orig_eid
[
sampled_graph
.
edata
[
dgl
.
EID
]]
assert
(
assert
np
.
array_equal
(
F
.
asnumpy
(
eids1
),
F
.
asnumpy
(
eids
))
dgl
.
EID
not
in
sampled_graph
.
edata
),
"EID should not be in sampled graph if use_graphbolt=True."
else
:
eids
=
g
.
edge_ids
(
src
,
dst
)
eids1
=
orig_eid
[
sampled_graph
.
edata
[
dgl
.
EID
]]
assert
np
.
array_equal
(
F
.
asnumpy
(
eids1
),
F
.
asnumpy
(
eids
))
def
start_find_edges_client
(
rank
,
tmpdir
,
disable_shared_mem
,
eids
,
etype
=
None
):
def
start_find_edges_client
(
rank
,
tmpdir
,
disable_shared_mem
,
eids
,
etype
=
None
):
...
@@ -378,7 +390,9 @@ def test_rpc_sampling():
...
@@ -378,7 +390,9 @@ def test_rpc_sampling():
check_rpc_sampling
(
Path
(
tmpdirname
),
1
)
check_rpc_sampling
(
Path
(
tmpdirname
),
1
)
def
check_rpc_sampling_shuffle
(
tmpdir
,
num_server
,
num_groups
=
1
):
def
check_rpc_sampling_shuffle
(
tmpdir
,
num_server
,
num_groups
=
1
,
use_graphbolt
=
False
):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
g
=
CitationGraphDataset
(
"cora"
)[
0
]
g
=
CitationGraphDataset
(
"cora"
)[
0
]
...
@@ -393,6 +407,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
...
@@ -393,6 +407,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
num_hops
=
num_hops
,
num_hops
=
num_hops
,
part_method
=
"metis"
,
part_method
=
"metis"
,
return_mapping
=
True
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
)
)
pserver_list
=
[]
pserver_list
=
[]
...
@@ -406,6 +421,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
...
@@ -406,6 +421,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
num_server
>
1
,
num_server
>
1
,
"test_sampling"
,
"test_sampling"
,
[
"csc"
,
"coo"
],
[
"csc"
,
"coo"
],
use_graphbolt
,
),
),
)
)
p
.
start
()
p
.
start
()
...
@@ -427,6 +443,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
...
@@ -427,6 +443,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
group_id
,
group_id
,
orig_nids
,
orig_nids
,
orig_eids
,
orig_eids
,
use_graphbolt
,
),
),
)
)
p
.
start
()
p
.
start
()
...
@@ -1012,6 +1029,9 @@ def test_rpc_sampling_shuffle(num_server):
...
@@ -1012,6 +1029,9 @@ def test_rpc_sampling_shuffle(num_server):
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
True
)
check_rpc_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
check_rpc_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
# [TODO][Rhett] Tests for multiple groups may fail sometimes and
# [TODO][Rhett] Tests for multiple groups may fail sometimes and
# root cause is unknown. Let's disable them for now.
# root cause is unknown. Let's disable them for now.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment