Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
f3af2a9f
Unverified
Commit
f3af2a9f
authored
Feb 05, 2024
by
Rhett Ying
Committed by
GitHub
Feb 05, 2024
Browse files
[DistGB] return eids together with etype_ids in sampling (#7084)
parent
346197c4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
171 additions
and
77 deletions
+171
-77
python/dgl/distributed/graph_services.py
python/dgl/distributed/graph_services.py
+107
-51
tests/distributed/test_distributed_sampling.py
tests/distributed/test_distributed_sampling.py
+64
-26
No files found.
python/dgl/distributed/graph_services.py
View file @
f3af2a9f
...
...
@@ -6,7 +6,7 @@ import numpy as np
import
torch
from
..
import
backend
as
F
,
graphbolt
as
gb
from
..base
import
EID
,
NID
from
..base
import
EID
,
ETYPE
,
NID
from
..convert
import
graph
,
heterograph
from
..sampling
import
(
sample_etype_neighbors
as
local_sample_etype_neighbors
,
...
...
@@ -40,16 +40,29 @@ ETYPE_SAMPLING_SERVICE_ID = 6662
class
SubgraphResponse
(
Response
):
"""The response for sampling and in_subgraph"""
def
__init__
(
self
,
global_src
,
global_dst
,
global_eids
):
def
__init__
(
self
,
global_src
,
global_dst
,
*
,
global_eids
=
None
,
etype_ids
=
None
):
self
.
global_src
=
global_src
self
.
global_dst
=
global_dst
self
.
global_eids
=
global_eids
self
.
etype_ids
=
etype_ids
def
__setstate__
(
self
,
state
):
self
.
global_src
,
self
.
global_dst
,
self
.
global_eids
=
state
(
self
.
global_src
,
self
.
global_dst
,
self
.
global_eids
,
self
.
etype_ids
,
)
=
state
def
__getstate__
(
self
):
return
self
.
global_src
,
self
.
global_dst
,
self
.
global_eids
return
(
self
.
global_src
,
self
.
global_dst
,
self
.
global_eids
,
self
.
etype_ids
,
)
class
FindEdgeResponse
(
Response
):
...
...
@@ -68,7 +81,7 @@ class FindEdgeResponse(Response):
def
_sample_neighbors_graphbolt
(
g
,
gpb
,
nodes
,
fanout
,
prob
=
None
,
replace
=
False
g
,
gpb
,
nodes
,
fanout
,
edge_dir
=
"in"
,
prob
=
None
,
replace
=
False
):
"""Sample from local partition via graphbolt.
...
...
@@ -77,8 +90,6 @@ def _sample_neighbors_graphbolt(
space again. The sampled results are stored in three vectors that store
source nodes, destination nodes, etype IDs and edge IDs.
[Rui][TODO] edge IDs are not returned as not supported yet.
Parameters
----------
g : FusedCSCSamplingGraph
...
...
@@ -89,6 +100,8 @@ def _sample_neighbors_graphbolt(
The nodes to sample neighbors from.
fanout : tensor or int
The number of edges to be sampled for each node.
edge_dir : str, optional
Determines whether to sample inbound or outbound edges.
prob : tensor, optional
The probability associated with each neighboring edge of a node.
replace : bool, optional
...
...
@@ -100,11 +113,15 @@ def _sample_neighbors_graphbolt(
The source node ID array.
tensor
The destination node ID array.
tensor
The edge type ID array.
tensor
The edge ID array.
tensor
The edge type ID array.
"""
assert
(
edge_dir
==
"in"
),
f
"GraphBolt only supports inbound edge sampling but got
{
edge_dir
}
."
# 1. Map global node IDs to local node IDs.
nodes
=
gpb
.
nid2localnid
(
nodes
,
gpb
.
partid
)
...
...
@@ -139,11 +156,20 @@ def _sample_neighbors_graphbolt(
global_src
=
global_nid_mapping
[
local_src
]
global_dst
=
global_nid_mapping
[
local_dst
]
return
global_src
,
global_dst
,
subgraph
.
type_per_edge
# [Rui][TODO] edge IDs are not supported yet.
return
LocalSampledGraph
(
global_src
,
global_dst
,
None
,
subgraph
.
type_per_edge
)
def
_sample_neighbors
(
local_g
,
partition_book
,
seed_nodes
,
fan_out
,
edge_dir
,
prob
,
replace
def
_sample_neighbors_dgl
(
local_g
,
partition_book
,
seed_nodes
,
fan_out
,
edge_dir
=
"in"
,
prob
=
None
,
replace
=
False
,
):
"""Sample from local partition.
...
...
@@ -170,7 +196,38 @@ def _sample_neighbors(
global_nid_mapping
,
src
),
F
.
gather_row
(
global_nid_mapping
,
dst
)
global_eids
=
F
.
gather_row
(
local_g
.
edata
[
EID
],
sampled_graph
.
edata
[
EID
])
return
global_src
,
global_dst
,
global_eids
return
LocalSampledGraph
(
global_src
,
global_dst
,
global_eids
)
def
_sample_neighbors
(
use_graphbolt
,
*
args
,
**
kwargs
):
"""Wrapper for sampling neighbors.
The actual sampling function depends on whether to use GraphBolt.
Parameters
----------
use_graphbolt : bool
Whether to use GraphBolt for sampling.
args : list
The arguments for the sampling function.
kwargs : dict
The keyword arguments for the sampling function.
Returns
-------
tensor
The source node ID array.
tensor
The destination node ID array.
tensor
The edge ID array.
tensor
The edge type ID array.
"""
func
=
(
_sample_neighbors_graphbolt
if
use_graphbolt
else
_sample_neighbors_dgl
)
return
func
(
*
args
,
**
kwargs
)
def
_sample_etype_neighbors
(
...
...
@@ -211,7 +268,7 @@ def _sample_etype_neighbors(
global_nid_mapping
,
src
),
F
.
gather_row
(
global_nid_mapping
,
dst
)
global_eids
=
F
.
gather_row
(
local_g
.
edata
[
EID
],
sampled_graph
.
edata
[
EID
])
return
global_src
,
global_dst
,
global_eids
return
LocalSampledGraph
(
global_src
,
global_dst
,
global_eids
)
def
_find_edges
(
local_g
,
partition_book
,
seed_edges
):
...
...
@@ -257,7 +314,7 @@ def _in_subgraph(local_g, partition_book, seed_nodes):
src
,
dst
=
sampled_graph
.
edges
()
global_src
,
global_dst
=
global_nid_mapping
[
src
],
global_nid_mapping
[
dst
]
global_eids
=
F
.
gather_row
(
local_g
.
edata
[
EID
],
sampled_graph
.
edata
[
EID
])
return
global_src
,
global_dst
,
global_eids
return
LocalSampledGraph
(
global_src
,
global_dst
,
global_eids
)
# --- NOTE 1 ---
...
...
@@ -333,26 +390,22 @@ class SamplingRequest(Request):
prob
=
[
kv_store
.
data_store
[
self
.
prob
]]
else
:
prob
=
None
if
self
.
use_graphbolt
:
global_src
,
global_dst
,
etype_ids
=
_sample_neighbors_graphbolt
(
local_g
,
partition_book
,
self
.
seed_nodes
,
self
.
fan_out
,
prob
,
self
.
replace
,
)
return
SubgraphResponse
(
global_src
,
global_dst
,
etype_ids
)
global_src
,
global_dst
,
global_eids
=
_sample_neighbors
(
res
=
_sample_neighbors
(
self
.
use_graphbolt
,
local_g
,
partition_book
,
self
.
seed_nodes
,
self
.
fan_out
,
self
.
edge_dir
,
prob
,
self
.
replace
,
edge_dir
=
self
.
edge_dir
,
prob
=
prob
,
replace
=
self
.
replace
,
)
return
SubgraphResponse
(
res
.
global_src
,
res
.
global_dst
,
global_eids
=
res
.
global_eids
,
etype_ids
=
res
.
etype_ids
,
)
return
SubgraphResponse
(
global_src
,
global_dst
,
global_eids
)
class
SamplingRequestEtype
(
Request
):
...
...
@@ -407,7 +460,7 @@ class SamplingRequestEtype(Request):
]
else
:
probs
=
None
global_src
,
global_dst
,
global_eid
s
=
_sample_etype_neighbors
(
re
s
=
_sample_etype_neighbors
(
local_g
,
partition_book
,
self
.
seed_nodes
,
...
...
@@ -418,7 +471,12 @@ class SamplingRequestEtype(Request):
self
.
replace
,
self
.
etype_sorted
,
)
return
SubgraphResponse
(
global_src
,
global_dst
,
global_eids
)
return
SubgraphResponse
(
res
.
global_src
,
res
.
global_dst
,
global_eids
=
res
.
global_eids
,
etype_ids
=
res
.
etype_ids
,
)
class
EdgesRequest
(
Request
):
...
...
@@ -532,7 +590,7 @@ class InSubgraphRequest(Request):
global_src
,
global_dst
,
global_eids
=
_in_subgraph
(
local_g
,
partition_book
,
self
.
seed_nodes
)
return
SubgraphResponse
(
global_src
,
global_dst
,
global_eids
)
return
SubgraphResponse
(
global_src
,
global_dst
,
global_eids
=
global_eids
)
def
merge_graphs
(
res_list
,
num_nodes
):
...
...
@@ -541,25 +599,33 @@ def merge_graphs(res_list, num_nodes):
srcs
=
[]
dsts
=
[]
eids
=
[]
etype_ids
=
[]
for
res
in
res_list
:
srcs
.
append
(
res
.
global_src
)
dsts
.
append
(
res
.
global_dst
)
eids
.
append
(
res
.
global_eids
)
etype_ids
.
append
(
res
.
etype_ids
)
src_tensor
=
F
.
cat
(
srcs
,
0
)
dst_tensor
=
F
.
cat
(
dsts
,
0
)
eid_tensor
=
None
if
eids
[
0
]
is
None
else
F
.
cat
(
eids
,
0
)
etype_id_tensor
=
None
if
etype_ids
[
0
]
is
None
else
F
.
cat
(
etype_ids
,
0
)
else
:
src_tensor
=
res_list
[
0
].
global_src
dst_tensor
=
res_list
[
0
].
global_dst
eid_tensor
=
res_list
[
0
].
global_eids
etype_id_tensor
=
res_list
[
0
].
etype_ids
g
=
graph
((
src_tensor
,
dst_tensor
),
num_nodes
=
num_nodes
)
if
eid_tensor
is
not
None
:
g
.
edata
[
EID
]
=
eid_tensor
if
etype_id_tensor
is
not
None
:
g
.
edata
[
ETYPE
]
=
etype_id_tensor
return
g
LocalSampledGraph
=
namedtuple
(
"LocalSampledGraph"
,
"global_src global_dst global_eids"
LocalSampledGraph
=
namedtuple
(
# pylint: disable=unexpected-keyword-arg
"LocalSampledGraph"
,
"global_src global_dst global_eids etype_ids"
,
defaults
=
(
None
,
None
,
None
,
None
),
)
...
...
@@ -615,10 +681,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
# sample neighbors for the nodes in the local partition.
res_list
=
[]
if
local_nids
is
not
None
:
src
,
dst
,
eids
=
local_access
(
g
.
local_partition
,
partition_book
,
local_nids
)
res_list
.
append
(
LocalSampledGraph
(
src
,
dst
,
eids
))
res
=
local_access
(
g
.
local_partition
,
partition_book
,
local_nids
)
res_list
.
append
(
res
)
# receive responses from remote machines.
if
msgseq2pos
is
not
None
:
...
...
@@ -916,23 +980,15 @@ def sample_neighbors(
def
local_access
(
local_g
,
partition_book
,
local_nids
):
# See NOTE 1
_prob
=
[
g
.
edata
[
prob
].
local_partition
]
if
prob
is
not
None
else
None
if
use_graphbolt
:
return
_sample_neighbors_graphbolt
(
local_g
,
partition_book
,
local_nids
,
fanout
,
prob
=
_prob
,
replace
=
replace
,
)
return
_sample_neighbors
(
use_graphbolt
,
local_g
,
partition_book
,
local_nids
,
fanout
,
edge_dir
,
_prob
,
replace
,
edge_dir
=
edge_dir
,
prob
=
_prob
,
replace
=
replace
,
)
frontier
=
_distributed_access
(
g
,
nodes
,
issue_remote_req
,
local_access
)
...
...
tests/distributed/test_distributed_sampling.py
View file @
f3af2a9f
import
multiprocessing
as
mp
import
os
import
random
import
sys
import
tempfile
import
time
import
traceback
import
unittest
...
...
@@ -1013,47 +1013,85 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
assert
np
.
all
(
F
.
asnumpy
(
orig_dst1
)
==
orig_dst
)
# Wait non shared memory graph store
@
unittest
.
skipIf
(
os
.
name
==
"nt"
,
reason
=
"Do not support windows yet"
)
@
unittest
.
skipIf
(
dgl
.
backend
.
backend_name
==
"tensorflow"
,
reason
=
"Not support tensorflow for now"
,
)
@
unittest
.
skipIf
(
dgl
.
backend
.
backend_name
==
"mxnet"
,
reason
=
"Turn off Mxnet support"
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_sampling_shuffle
(
num_server
):
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
def
test_rpc_sampling_shuffle
(
num_server
,
use_graphbolt
):
reset_envs
()
import
tempfile
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
True
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
use_graphbolt
)
check_rpc_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
# [TODO][Rhett] Tests for multiple groups may fail sometimes and
# root cause is unknown. Let's disable them for now.
# check_rpc_sampling_shuffle(Path(tmpdirname), num_server, num_groups=2)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_hetero_sampling_shuffle
(
num_server
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_hetero_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_hetero_sampling_empty_shuffle
(
num_server
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_hetero_sampling_empty_shuffle
(
Path
(
tmpdirname
),
num_server
)
check_rpc_hetero_etype_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
check_rpc_hetero_etype_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
,
[
"csc"
]
)
check_rpc_hetero_etype_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
,
[
"csr"
]
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"graph_formats"
,
[
None
,
[
"csc"
],
[
"csr"
],
[
"csc"
,
"coo"
]]
)
def
test_rpc_hetero_etype_sampling_shuffle
(
num_server
,
graph_formats
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_hetero_etype_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
,
[
"csc"
,
"coo"
]
Path
(
tmpdirname
),
num_server
,
graph_formats
=
graph_formats
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_hetero_etype_sampling_empty_shuffle
(
num_server
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_hetero_etype_sampling_empty_shuffle
(
Path
(
tmpdirname
),
num_server
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_bipartite_sampling_empty_shuffle
(
num_server
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_bipartite_sampling_empty
(
Path
(
tmpdirname
),
num_server
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_bipartite_sampling_shuffle
(
num_server
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_bipartite_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_bipartite_etype_sampling_empty_shuffle
(
num_server
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_bipartite_etype_sampling_empty
(
Path
(
tmpdirname
),
num_server
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_bipartite_etype_sampling_shuffle
(
num_server
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_bipartite_etype_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment