Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
6735a3ae
Unverified
Commit
6735a3ae
authored
Feb 09, 2024
by
Rhett Ying
Committed by
GitHub
Feb 09, 2024
Browse files
[DistGB] enable sample etype neighbors on heterograph (#7095)
parent
3ebdee77
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
234 additions
and
50 deletions
+234
-50
python/dgl/distributed/graph_services.py
python/dgl/distributed/graph_services.py
+63
-16
tests/distributed/test_distributed_sampling.py
tests/distributed/test_distributed_sampling.py
+171
-34
No files found.
python/dgl/distributed/graph_services.py
View file @
6735a3ae
...
@@ -143,8 +143,6 @@ def _sample_neighbors_graphbolt(
...
@@ -143,8 +143,6 @@ def _sample_neighbors_graphbolt(
if
isinstance
(
fanout
,
int
):
if
isinstance
(
fanout
,
int
):
fanout
=
torch
.
LongTensor
([
fanout
])
fanout
=
torch
.
LongTensor
([
fanout
])
assert
isinstance
(
fanout
,
torch
.
Tensor
),
"Expect a tensor of fanout."
assert
isinstance
(
fanout
,
torch
.
Tensor
),
"Expect a tensor of fanout."
# [Rui][TODO] Support multiple fanouts.
assert
fanout
.
numel
()
==
1
,
"Expect a single fanout."
return_eids
=
g
.
edge_attributes
is
not
None
and
EID
in
g
.
edge_attributes
return_eids
=
g
.
edge_attributes
is
not
None
and
EID
in
g
.
edge_attributes
subgraph
=
g
.
_sample_neighbors
(
nodes
,
fanout
,
return_eids
=
return_eids
)
subgraph
=
g
.
_sample_neighbors
(
nodes
,
fanout
,
return_eids
=
return_eids
)
...
@@ -237,15 +235,15 @@ def _sample_neighbors(use_graphbolt, *args, **kwargs):
...
@@ -237,15 +235,15 @@ def _sample_neighbors(use_graphbolt, *args, **kwargs):
return
func
(
*
args
,
**
kwargs
)
return
func
(
*
args
,
**
kwargs
)
def
_sample_etype_neighbors
(
def
_sample_etype_neighbors
_dgl
(
local_g
,
local_g
,
partition_book
,
partition_book
,
seed_nodes
,
seed_nodes
,
etype_offset
,
fan_out
,
fan_out
,
edge_dir
,
edge_dir
=
"in"
,
prob
,
prob
=
None
,
replace
,
replace
=
False
,
etype_offset
=
None
,
etype_sorted
=
False
,
etype_sorted
=
False
,
):
):
"""Sample from local partition.
"""Sample from local partition.
...
@@ -255,6 +253,8 @@ def _sample_etype_neighbors(
...
@@ -255,6 +253,8 @@ def _sample_etype_neighbors(
The sampled results are stored in three vectors that store source nodes, destination nodes
The sampled results are stored in three vectors that store source nodes, destination nodes
and edge IDs.
and edge IDs.
"""
"""
assert
etype_offset
is
not
None
,
"The etype offset is not provided."
local_ids
=
partition_book
.
nid2localnid
(
seed_nodes
,
partition_book
.
partid
)
local_ids
=
partition_book
.
nid2localnid
(
seed_nodes
,
partition_book
.
partid
)
local_ids
=
F
.
astype
(
local_ids
,
local_g
.
idtype
)
local_ids
=
F
.
astype
(
local_ids
,
local_g
.
idtype
)
...
@@ -278,6 +278,43 @@ def _sample_etype_neighbors(
...
@@ -278,6 +278,43 @@ def _sample_etype_neighbors(
return
LocalSampledGraph
(
global_src
,
global_dst
,
global_eids
)
return
LocalSampledGraph
(
global_src
,
global_dst
,
global_eids
)
def
_sample_etype_neighbors
(
use_graphbolt
,
*
args
,
**
kwargs
):
"""Wrapper for sampling etype neighbors.
The actual sampling function depends on whether to use GraphBolt.
Parameters
----------
use_graphbolt : bool
Whether to use GraphBolt for sampling.
args : list
The arguments for the sampling function.
kwargs : dict
The keyword arguments for the sampling function.
Returns
-------
tensor
The source node ID array.
tensor
The destination node ID array.
tensor
The edge ID array.
tensor
The edge type ID array.
"""
func
=
(
_sample_neighbors_graphbolt
if
use_graphbolt
else
_sample_etype_neighbors_dgl
)
if
use_graphbolt
:
# GraphBolt does not require `etype_offset` and `etype_sorted`.
kwargs
.
pop
(
"etype_offset"
,
None
)
kwargs
.
pop
(
"etype_sorted"
,
None
)
return
func
(
*
args
,
**
kwargs
)
def
_find_edges
(
local_g
,
partition_book
,
seed_edges
):
def
_find_edges
(
local_g
,
partition_book
,
seed_edges
):
"""Given an edge ID array, return the source
"""Given an edge ID array, return the source
and destination node ID array ``s`` and ``d`` in the local partition.
and destination node ID array ``s`` and ``d`` in the local partition.
...
@@ -426,6 +463,7 @@ class SamplingRequestEtype(Request):
...
@@ -426,6 +463,7 @@ class SamplingRequestEtype(Request):
prob
=
None
,
prob
=
None
,
replace
=
False
,
replace
=
False
,
etype_sorted
=
True
,
etype_sorted
=
True
,
use_graphbolt
=
False
,
):
):
self
.
seed_nodes
=
nodes
self
.
seed_nodes
=
nodes
self
.
edge_dir
=
edge_dir
self
.
edge_dir
=
edge_dir
...
@@ -433,6 +471,7 @@ class SamplingRequestEtype(Request):
...
@@ -433,6 +471,7 @@ class SamplingRequestEtype(Request):
self
.
replace
=
replace
self
.
replace
=
replace
self
.
fan_out
=
fan_out
self
.
fan_out
=
fan_out
self
.
etype_sorted
=
etype_sorted
self
.
etype_sorted
=
etype_sorted
self
.
use_graphbolt
=
use_graphbolt
def
__setstate__
(
self
,
state
):
def
__setstate__
(
self
,
state
):
(
(
...
@@ -442,6 +481,7 @@ class SamplingRequestEtype(Request):
...
@@ -442,6 +481,7 @@ class SamplingRequestEtype(Request):
self
.
replace
,
self
.
replace
,
self
.
fan_out
,
self
.
fan_out
,
self
.
etype_sorted
,
self
.
etype_sorted
,
self
.
use_graphbolt
,
)
=
state
)
=
state
def
__getstate__
(
self
):
def
__getstate__
(
self
):
...
@@ -452,6 +492,7 @@ class SamplingRequestEtype(Request):
...
@@ -452,6 +492,7 @@ class SamplingRequestEtype(Request):
self
.
replace
,
self
.
replace
,
self
.
fan_out
,
self
.
fan_out
,
self
.
etype_sorted
,
self
.
etype_sorted
,
self
.
use_graphbolt
,
)
)
def
process_request
(
self
,
server_state
):
def
process_request
(
self
,
server_state
):
...
@@ -468,15 +509,16 @@ class SamplingRequestEtype(Request):
...
@@ -468,15 +509,16 @@ class SamplingRequestEtype(Request):
else
:
else
:
probs
=
None
probs
=
None
res
=
_sample_etype_neighbors
(
res
=
_sample_etype_neighbors
(
self
.
use_graphbolt
,
local_g
,
local_g
,
partition_book
,
partition_book
,
self
.
seed_nodes
,
self
.
seed_nodes
,
etype_offset
,
self
.
fan_out
,
self
.
fan_out
,
self
.
edge_dir
,
edge_dir
=
self
.
edge_dir
,
probs
,
prob
=
probs
,
self
.
replace
,
replace
=
self
.
replace
,
self
.
etype_sorted
,
etype_offset
=
etype_offset
,
etype_sorted
=
self
.
etype_sorted
,
)
)
return
SubgraphResponse
(
return
SubgraphResponse
(
res
.
global_src
,
res
.
global_src
,
...
@@ -772,6 +814,7 @@ def sample_etype_neighbors(
...
@@ -772,6 +814,7 @@ def sample_etype_neighbors(
prob
=
None
,
prob
=
None
,
replace
=
False
,
replace
=
False
,
etype_sorted
=
True
,
etype_sorted
=
True
,
use_graphbolt
=
False
,
):
):
"""Sample from the neighbors of the given nodes from a distributed graph.
"""Sample from the neighbors of the given nodes from a distributed graph.
...
@@ -825,6 +868,8 @@ def sample_etype_neighbors(
...
@@ -825,6 +868,8 @@ def sample_etype_neighbors(
neighbors are sampled. If fanout == -1, all neighbors are collected.
neighbors are sampled. If fanout == -1, all neighbors are collected.
etype_sorted : bool, optional
etype_sorted : bool, optional
Indicates whether etypes are sorted.
Indicates whether etypes are sorted.
use_graphbolt : bool, optional
Whether to use GraphBolt for sampling.
Returns
Returns
-------
-------
...
@@ -882,6 +927,7 @@ def sample_etype_neighbors(
...
@@ -882,6 +927,7 @@ def sample_etype_neighbors(
prob
=
_prob
,
prob
=
_prob
,
replace
=
replace
,
replace
=
replace
,
etype_sorted
=
etype_sorted
,
etype_sorted
=
etype_sorted
,
use_graphbolt
=
use_graphbolt
,
)
)
def
local_access
(
local_g
,
partition_book
,
local_nids
):
def
local_access
(
local_g
,
partition_book
,
local_nids
):
...
@@ -897,14 +943,15 @@ def sample_etype_neighbors(
...
@@ -897,14 +943,15 @@ def sample_etype_neighbors(
for
etype
in
g
.
canonical_etypes
for
etype
in
g
.
canonical_etypes
]
]
return
_sample_etype_neighbors
(
return
_sample_etype_neighbors
(
use_graphbolt
,
local_g
,
local_g
,
partition_book
,
partition_book
,
local_nids
,
local_nids
,
etype_offset
,
fanout
,
fanout
,
edge_dir
,
edge_dir
=
edge_dir
,
_prob
,
prob
=
_prob
,
replace
,
replace
=
replace
,
etype_offset
=
etype_offset
,
etype_sorted
=
etype_sorted
,
etype_sorted
=
etype_sorted
,
)
)
...
...
tests/distributed/test_distributed_sampling.py
View file @
6735a3ae
...
@@ -508,6 +508,8 @@ def start_hetero_etype_sample_client(
...
@@ -508,6 +508,8 @@ def start_hetero_etype_sample_client(
fanout
=
3
,
fanout
=
3
,
nodes
=
{
"n3"
:
[
0
,
10
,
99
,
66
,
124
,
208
]},
nodes
=
{
"n3"
:
[
0
,
10
,
99
,
66
,
124
,
208
]},
etype_sorted
=
False
,
etype_sorted
=
False
,
use_graphbolt
=
False
,
return_eids
=
False
,
):
):
gpb
=
None
gpb
=
None
if
disable_shared_mem
:
if
disable_shared_mem
:
...
@@ -515,12 +517,14 @@ def start_hetero_etype_sample_client(
...
@@ -515,12 +517,14 @@ def start_hetero_etype_sample_client(
tmpdir
/
"test_sampling.json"
,
rank
tmpdir
/
"test_sampling.json"
,
rank
)
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
assert
"feat"
in
dist_graph
.
nodes
[
"n1"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"n1"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n2"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n2"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n3"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n3"
].
data
if
dist_graph
.
local_partition
is
not
None
:
if
(
not
use_graphbolt
)
and
dist_graph
.
local_partition
is
not
None
:
# Check whether etypes are sorted in dist_graph
# Check whether etypes are sorted in dist_graph
local_g
=
dist_graph
.
local_partition
local_g
=
dist_graph
.
local_partition
local_nids
=
np
.
arange
(
local_g
.
num_nodes
())
local_nids
=
np
.
arange
(
local_g
.
num_nodes
())
...
@@ -533,10 +537,18 @@ def start_hetero_etype_sample_client(
...
@@ -533,10 +537,18 @@ def start_hetero_etype_sample_client(
if
gpb
is
None
:
if
gpb
is
None
:
gpb
=
dist_graph
.
get_partition_book
()
gpb
=
dist_graph
.
get_partition_book
()
try
:
try
:
# Enable santity check in distributed sampling.
os
.
environ
[
"DGL_DIST_DEBUG"
]
=
"1"
sampled_graph
=
sample_etype_neighbors
(
sampled_graph
=
sample_etype_neighbors
(
dist_graph
,
nodes
,
fanout
,
etype_sorted
=
etype_sorted
dist_graph
,
nodes
,
fanout
,
etype_sorted
=
etype_sorted
,
use_graphbolt
=
use_graphbolt
,
)
)
block
=
dgl
.
to_block
(
sampled_graph
,
nodes
)
block
=
dgl
.
to_block
(
sampled_graph
,
nodes
)
if
sampled_graph
.
num_edges
()
>
0
:
if
not
use_graphbolt
or
return_eids
:
block
.
edata
[
dgl
.
EID
]
=
sampled_graph
.
edata
[
dgl
.
EID
]
block
.
edata
[
dgl
.
EID
]
=
sampled_graph
.
edata
[
dgl
.
EID
]
except
Exception
as
e
:
except
Exception
as
e
:
print
(
traceback
.
format_exc
())
print
(
traceback
.
format_exc
())
...
@@ -689,7 +701,11 @@ def check_rpc_hetero_sampling_empty_shuffle(
...
@@ -689,7 +701,11 @@ def check_rpc_hetero_sampling_empty_shuffle(
def
check_rpc_hetero_etype_sampling_shuffle
(
def
check_rpc_hetero_etype_sampling_shuffle
(
tmpdir
,
num_server
,
graph_formats
=
None
tmpdir
,
num_server
,
graph_formats
=
None
,
use_graphbolt
=
False
,
return_eids
=
False
,
):
):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
...
@@ -706,6 +722,8 @@ def check_rpc_hetero_etype_sampling_shuffle(
...
@@ -706,6 +722,8 @@ def check_rpc_hetero_etype_sampling_shuffle(
part_method
=
"metis"
,
part_method
=
"metis"
,
return_mapping
=
True
,
return_mapping
=
True
,
graph_formats
=
graph_formats
,
graph_formats
=
graph_formats
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
)
pserver_list
=
[]
pserver_list
=
[]
...
@@ -713,7 +731,14 @@ def check_rpc_hetero_etype_sampling_shuffle(
...
@@ -713,7 +731,14 @@ def check_rpc_hetero_etype_sampling_shuffle(
for
i
in
range
(
num_server
):
for
i
in
range
(
num_server
):
p
=
ctx
.
Process
(
p
=
ctx
.
Process
(
target
=
start_server
,
target
=
start_server
,
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
]),
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
],
use_graphbolt
,
),
)
)
p
.
start
()
p
.
start
()
time
.
sleep
(
1
)
time
.
sleep
(
1
)
...
@@ -730,6 +755,8 @@ def check_rpc_hetero_etype_sampling_shuffle(
...
@@ -730,6 +755,8 @@ def check_rpc_hetero_etype_sampling_shuffle(
fanout
,
fanout
,
nodes
=
{
"n3"
:
[
0
,
10
,
99
,
66
,
124
,
208
]},
nodes
=
{
"n3"
:
[
0
,
10
,
99
,
66
,
124
,
208
]},
etype_sorted
=
etype_sorted
,
etype_sorted
=
etype_sorted
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
)
print
(
"Done sampling"
)
print
(
"Done sampling"
)
for
p
in
pserver_list
:
for
p
in
pserver_list
:
...
@@ -747,19 +774,26 @@ def check_rpc_hetero_etype_sampling_shuffle(
...
@@ -747,19 +774,26 @@ def check_rpc_hetero_etype_sampling_shuffle(
# These are global Ids after shuffling.
# These are global Ids after shuffling.
shuffled_src
=
F
.
gather_row
(
block
.
srcnodes
[
src_type
].
data
[
dgl
.
NID
],
src
)
shuffled_src
=
F
.
gather_row
(
block
.
srcnodes
[
src_type
].
data
[
dgl
.
NID
],
src
)
shuffled_dst
=
F
.
gather_row
(
block
.
dstnodes
[
dst_type
].
data
[
dgl
.
NID
],
dst
)
shuffled_dst
=
F
.
gather_row
(
block
.
dstnodes
[
dst_type
].
data
[
dgl
.
NID
],
dst
)
shuffled_eid
=
block
.
edges
[
etype
].
data
[
dgl
.
EID
]
orig_src
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
src_type
],
shuffled_src
))
orig_src
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
src_type
],
shuffled_src
))
orig_dst
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
dst_type
],
shuffled_dst
))
orig_dst
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
dst_type
],
shuffled_dst
))
orig_eid
=
F
.
asnumpy
(
F
.
gather_row
(
orig_eid_map
[
c_etype
],
shuffled_eid
))
assert
np
.
all
(
F
.
asnumpy
(
g
.
has_edges_between
(
orig_src
,
orig_dst
,
etype
=
etype
))
)
if
use_graphbolt
and
not
return_eids
:
continue
# Check the node Ids and edge Ids.
# Check the node Ids and edge Ids.
shuffled_eid
=
block
.
edges
[
etype
].
data
[
dgl
.
EID
]
orig_eid
=
F
.
asnumpy
(
F
.
gather_row
(
orig_eid_map
[
c_etype
],
shuffled_eid
))
orig_src1
,
orig_dst1
=
g
.
find_edges
(
orig_eid
,
etype
=
etype
)
orig_src1
,
orig_dst1
=
g
.
find_edges
(
orig_eid
,
etype
=
etype
)
assert
np
.
all
(
F
.
asnumpy
(
orig_src1
)
==
orig_src
)
assert
np
.
all
(
F
.
asnumpy
(
orig_src1
)
==
orig_src
)
assert
np
.
all
(
F
.
asnumpy
(
orig_dst1
)
==
orig_dst
)
assert
np
.
all
(
F
.
asnumpy
(
orig_dst1
)
==
orig_dst
)
def
check_rpc_hetero_etype_sampling_empty_shuffle
(
tmpdir
,
num_server
):
def
check_rpc_hetero_etype_sampling_empty_shuffle
(
tmpdir
,
num_server
,
use_graphbolt
=
False
,
return_eids
=
False
):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
g
=
create_random_hetero
(
dense
=
True
,
empty
=
True
)
g
=
create_random_hetero
(
dense
=
True
,
empty
=
True
)
...
@@ -774,6 +808,8 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
...
@@ -774,6 +808,8 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
num_hops
=
num_hops
,
num_hops
=
num_hops
,
part_method
=
"metis"
,
part_method
=
"metis"
,
return_mapping
=
True
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
)
pserver_list
=
[]
pserver_list
=
[]
...
@@ -781,7 +817,14 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
...
@@ -781,7 +817,14 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
for
i
in
range
(
num_server
):
for
i
in
range
(
num_server
):
p
=
ctx
.
Process
(
p
=
ctx
.
Process
(
target
=
start_server
,
target
=
start_server
,
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
),
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
],
use_graphbolt
,
),
)
)
p
.
start
()
p
.
start
()
time
.
sleep
(
1
)
time
.
sleep
(
1
)
...
@@ -791,7 +834,13 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
...
@@ -791,7 +834,13 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
deg
=
get_degrees
(
g
,
orig_nids
[
"n3"
],
"n3"
)
deg
=
get_degrees
(
g
,
orig_nids
[
"n3"
],
"n3"
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
block
,
gpb
=
start_hetero_etype_sample_client
(
block
,
gpb
=
start_hetero_etype_sample_client
(
0
,
tmpdir
,
num_server
>
1
,
fanout
,
nodes
=
{
"n3"
:
empty_nids
}
0
,
tmpdir
,
num_server
>
1
,
fanout
,
nodes
=
{
"n3"
:
empty_nids
},
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
)
print
(
"Done sampling"
)
print
(
"Done sampling"
)
for
p
in
pserver_list
:
for
p
in
pserver_list
:
...
@@ -848,7 +897,13 @@ def start_bipartite_sample_client(
...
@@ -848,7 +897,13 @@ def start_bipartite_sample_client(
def
start_bipartite_etype_sample_client
(
def
start_bipartite_etype_sample_client
(
rank
,
tmpdir
,
disable_shared_mem
,
fanout
=
3
,
nodes
=
{}
rank
,
tmpdir
,
disable_shared_mem
,
fanout
=
3
,
nodes
=
{},
use_graphbolt
=
False
,
return_eids
=
False
,
):
):
gpb
=
None
gpb
=
None
if
disable_shared_mem
:
if
disable_shared_mem
:
...
@@ -856,11 +911,13 @@ def start_bipartite_etype_sample_client(
...
@@ -856,11 +911,13 @@ def start_bipartite_etype_sample_client(
tmpdir
/
"test_sampling.json"
,
rank
tmpdir
/
"test_sampling.json"
,
rank
)
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
assert
"feat"
in
dist_graph
.
nodes
[
"user"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"user"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"game"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"game"
].
data
if
dist_graph
.
local_partition
is
not
None
:
if
not
use_graphbolt
and
dist_graph
.
local_partition
is
not
None
:
# Check whether etypes are sorted in dist_graph
# Check whether etypes are sorted in dist_graph
local_g
=
dist_graph
.
local_partition
local_g
=
dist_graph
.
local_partition
local_nids
=
np
.
arange
(
local_g
.
num_nodes
())
local_nids
=
np
.
arange
(
local_g
.
num_nodes
())
...
@@ -872,9 +929,12 @@ def start_bipartite_etype_sample_client(
...
@@ -872,9 +929,12 @@ def start_bipartite_etype_sample_client(
if
gpb
is
None
:
if
gpb
is
None
:
gpb
=
dist_graph
.
get_partition_book
()
gpb
=
dist_graph
.
get_partition_book
()
sampled_graph
=
sample_etype_neighbors
(
dist_graph
,
nodes
,
fanout
)
sampled_graph
=
sample_etype_neighbors
(
dist_graph
,
nodes
,
fanout
,
use_graphbolt
=
use_graphbolt
)
block
=
dgl
.
to_block
(
sampled_graph
,
nodes
)
block
=
dgl
.
to_block
(
sampled_graph
,
nodes
)
if
sampled_graph
.
num_edges
()
>
0
:
if
sampled_graph
.
num_edges
()
>
0
:
if
not
use_graphbolt
or
return_eids
:
block
.
edata
[
dgl
.
EID
]
=
sampled_graph
.
edata
[
dgl
.
EID
]
block
.
edata
[
dgl
.
EID
]
=
sampled_graph
.
edata
[
dgl
.
EID
]
dgl
.
distributed
.
exit_client
()
dgl
.
distributed
.
exit_client
()
return
block
,
gpb
return
block
,
gpb
...
@@ -1019,7 +1079,9 @@ def check_rpc_bipartite_sampling_shuffle(
...
@@ -1019,7 +1079,9 @@ def check_rpc_bipartite_sampling_shuffle(
assert
np
.
all
(
F
.
asnumpy
(
orig_dst1
)
==
orig_dst
)
assert
np
.
all
(
F
.
asnumpy
(
orig_dst1
)
==
orig_dst
)
def
check_rpc_bipartite_etype_sampling_empty
(
tmpdir
,
num_server
):
def
check_rpc_bipartite_etype_sampling_empty
(
tmpdir
,
num_server
,
use_graphbolt
=
False
,
return_eids
=
False
):
"""sample on bipartite via sample_etype_neighbors() which yields empty sample results"""
"""sample on bipartite via sample_etype_neighbors() which yields empty sample results"""
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
...
@@ -1035,6 +1097,8 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
...
@@ -1035,6 +1097,8 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
num_hops
=
num_hops
,
num_hops
=
num_hops
,
part_method
=
"metis"
,
part_method
=
"metis"
,
return_mapping
=
True
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
)
pserver_list
=
[]
pserver_list
=
[]
...
@@ -1042,7 +1106,14 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
...
@@ -1042,7 +1106,14 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
for
i
in
range
(
num_server
):
for
i
in
range
(
num_server
):
p
=
ctx
.
Process
(
p
=
ctx
.
Process
(
target
=
start_server
,
target
=
start_server
,
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
),
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
],
use_graphbolt
,
),
)
)
p
.
start
()
p
.
start
()
time
.
sleep
(
1
)
time
.
sleep
(
1
)
...
@@ -1050,8 +1121,13 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
...
@@ -1050,8 +1121,13 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
deg
=
get_degrees
(
g
,
orig_nids
[
"game"
],
"game"
)
deg
=
get_degrees
(
g
,
orig_nids
[
"game"
],
"game"
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
block
,
gpb
=
start_bipartite_etype_sample_client
(
block
,
_
=
start_bipartite_etype_sample_client
(
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"game"
:
empty_nids
,
"user"
:
[
1
]}
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"game"
:
empty_nids
,
"user"
:
[
1
]},
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
)
print
(
"Done sampling"
)
print
(
"Done sampling"
)
...
@@ -1064,7 +1140,9 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
...
@@ -1064,7 +1140,9 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
assert
len
(
block
.
etypes
)
==
len
(
g
.
etypes
)
assert
len
(
block
.
etypes
)
==
len
(
g
.
etypes
)
def
check_rpc_bipartite_etype_sampling_shuffle
(
tmpdir
,
num_server
):
def
check_rpc_bipartite_etype_sampling_shuffle
(
tmpdir
,
num_server
,
use_graphbolt
=
False
,
return_eids
=
False
):
"""sample on bipartite via sample_etype_neighbors() which yields non-empty sample results"""
"""sample on bipartite via sample_etype_neighbors() which yields non-empty sample results"""
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
...
@@ -1080,6 +1158,8 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
...
@@ -1080,6 +1158,8 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
num_hops
=
num_hops
,
num_hops
=
num_hops
,
part_method
=
"metis"
,
part_method
=
"metis"
,
return_mapping
=
True
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
)
pserver_list
=
[]
pserver_list
=
[]
...
@@ -1087,7 +1167,14 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
...
@@ -1087,7 +1167,14 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
for
i
in
range
(
num_server
):
for
i
in
range
(
num_server
):
p
=
ctx
.
Process
(
p
=
ctx
.
Process
(
target
=
start_server
,
target
=
start_server
,
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
),
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
],
use_graphbolt
,
),
)
)
p
.
start
()
p
.
start
()
time
.
sleep
(
1
)
time
.
sleep
(
1
)
...
@@ -1097,7 +1184,13 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
...
@@ -1097,7 +1184,13 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
deg
=
get_degrees
(
g
,
orig_nid_map
[
"game"
],
"game"
)
deg
=
get_degrees
(
g
,
orig_nid_map
[
"game"
],
"game"
)
nids
=
F
.
nonzero_1d
(
deg
>
0
)
nids
=
F
.
nonzero_1d
(
deg
>
0
)
block
,
gpb
=
start_bipartite_etype_sample_client
(
block
,
gpb
=
start_bipartite_etype_sample_client
(
0
,
tmpdir
,
num_server
>
1
,
fanout
,
nodes
=
{
"game"
:
nids
,
"user"
:
[
0
]}
0
,
tmpdir
,
num_server
>
1
,
fanout
,
nodes
=
{
"game"
:
nids
,
"user"
:
[
0
]},
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
)
print
(
"Done sampling"
)
print
(
"Done sampling"
)
for
p
in
pserver_list
:
for
p
in
pserver_list
:
...
@@ -1110,13 +1203,18 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
...
@@ -1110,13 +1203,18 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
# These are global Ids after shuffling.
# These are global Ids after shuffling.
shuffled_src
=
F
.
gather_row
(
block
.
srcnodes
[
src_type
].
data
[
dgl
.
NID
],
src
)
shuffled_src
=
F
.
gather_row
(
block
.
srcnodes
[
src_type
].
data
[
dgl
.
NID
],
src
)
shuffled_dst
=
F
.
gather_row
(
block
.
dstnodes
[
dst_type
].
data
[
dgl
.
NID
],
dst
)
shuffled_dst
=
F
.
gather_row
(
block
.
dstnodes
[
dst_type
].
data
[
dgl
.
NID
],
dst
)
shuffled_eid
=
block
.
edges
[
etype
].
data
[
dgl
.
EID
]
orig_src
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
src_type
],
shuffled_src
))
orig_src
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
src_type
],
shuffled_src
))
orig_dst
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
dst_type
],
shuffled_dst
))
orig_dst
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
dst_type
],
shuffled_dst
))
orig_eid
=
F
.
asnumpy
(
F
.
gather_row
(
orig_eid_map
[
c_etype
],
shuffled_eid
))
assert
np
.
all
(
F
.
asnumpy
(
g
.
has_edges_between
(
orig_src
,
orig_dst
,
etype
=
etype
))
)
if
use_graphbolt
and
not
return_eids
:
continue
# Check the node Ids and edge Ids.
# Check the node Ids and edge Ids.
shuffled_eid
=
block
.
edges
[
etype
].
data
[
dgl
.
EID
]
orig_eid
=
F
.
asnumpy
(
F
.
gather_row
(
orig_eid_map
[
c_etype
],
shuffled_eid
))
orig_src1
,
orig_dst1
=
g
.
find_edges
(
orig_eid
,
etype
=
etype
)
orig_src1
,
orig_dst1
=
g
.
find_edges
(
orig_eid
,
etype
=
etype
)
assert
np
.
all
(
F
.
asnumpy
(
orig_src1
)
==
orig_src
)
assert
np
.
all
(
F
.
asnumpy
(
orig_src1
)
==
orig_src
)
assert
np
.
all
(
F
.
asnumpy
(
orig_dst1
)
==
orig_dst
)
assert
np
.
all
(
F
.
asnumpy
(
orig_dst1
)
==
orig_dst
)
...
@@ -1173,7 +1271,7 @@ def test_rpc_hetero_sampling_empty_shuffle(
...
@@ -1173,7 +1271,7 @@ def test_rpc_hetero_sampling_empty_shuffle(
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"graph_formats"
,
[
None
,
[
"csc"
],
[
"csr"
],
[
"csc"
,
"coo"
]]
"graph_formats"
,
[
None
,
[
"csc"
],
[
"csr"
],
[
"csc"
,
"coo"
]]
)
)
def
test_rpc_hetero_etype_sampling_shuffle
(
num_server
,
graph_formats
):
def
test_rpc_hetero_etype_sampling_shuffle
_dgl
(
num_server
,
graph_formats
):
reset_envs
()
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
...
@@ -1183,12 +1281,33 @@ def test_rpc_hetero_etype_sampling_shuffle(num_server, graph_formats):
...
@@ -1183,12 +1281,33 @@ def test_rpc_hetero_etype_sampling_shuffle(num_server, graph_formats):
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_hetero_etype_sampling_empty_shuffle
(
num_server
):
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_hetero_etype_sampling_shuffle_graphbolt
(
num_server
,
return_eids
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_hetero_etype_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
True
,
return_eids
=
return_eids
,
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_hetero_etype_sampling_empty_shuffle
(
num_server
,
use_graphbolt
,
return_eids
):
reset_envs
()
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_hetero_etype_sampling_empty_shuffle
(
check_rpc_hetero_etype_sampling_empty_shuffle
(
Path
(
tmpdirname
),
num_server
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
)
...
@@ -1219,19 +1338,37 @@ def test_rpc_bipartite_sampling_shuffle(num_server, use_graphbolt, return_eids):
...
@@ -1219,19 +1338,37 @@ def test_rpc_bipartite_sampling_shuffle(num_server, use_graphbolt, return_eids):
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_bipartite_etype_sampling_empty_shuffle
(
num_server
):
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_bipartite_etype_sampling_empty_shuffle
(
num_server
,
use_graphbolt
,
return_eids
):
reset_envs
()
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_bipartite_etype_sampling_empty
(
Path
(
tmpdirname
),
num_server
)
check_rpc_bipartite_etype_sampling_empty
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_bipartite_etype_sampling_shuffle
(
num_server
):
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_bipartite_etype_sampling_shuffle
(
num_server
,
use_graphbolt
,
return_eids
):
reset_envs
()
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_bipartite_etype_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
check_rpc_bipartite_etype_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
def
check_standalone_sampling
(
tmpdir
):
def
check_standalone_sampling
(
tmpdir
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment