Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
6735a3ae
Unverified
Commit
6735a3ae
authored
Feb 09, 2024
by
Rhett Ying
Committed by
GitHub
Feb 09, 2024
Browse files
[DistGB] enable sample etype neighbors on heterograph (#7095)
parent
3ebdee77
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
234 additions
and
50 deletions
+234
-50
python/dgl/distributed/graph_services.py
python/dgl/distributed/graph_services.py
+63
-16
tests/distributed/test_distributed_sampling.py
tests/distributed/test_distributed_sampling.py
+171
-34
No files found.
python/dgl/distributed/graph_services.py
View file @
6735a3ae
...
...
@@ -143,8 +143,6 @@ def _sample_neighbors_graphbolt(
if
isinstance
(
fanout
,
int
):
fanout
=
torch
.
LongTensor
([
fanout
])
assert
isinstance
(
fanout
,
torch
.
Tensor
),
"Expect a tensor of fanout."
# [Rui][TODO] Support multiple fanouts.
assert
fanout
.
numel
()
==
1
,
"Expect a single fanout."
return_eids
=
g
.
edge_attributes
is
not
None
and
EID
in
g
.
edge_attributes
subgraph
=
g
.
_sample_neighbors
(
nodes
,
fanout
,
return_eids
=
return_eids
)
...
...
@@ -237,15 +235,15 @@ def _sample_neighbors(use_graphbolt, *args, **kwargs):
return
func
(
*
args
,
**
kwargs
)
def
_sample_etype_neighbors
(
def
_sample_etype_neighbors
_dgl
(
local_g
,
partition_book
,
seed_nodes
,
etype_offset
,
fan_out
,
edge_dir
,
prob
,
replace
,
edge_dir
=
"in"
,
prob
=
None
,
replace
=
False
,
etype_offset
=
None
,
etype_sorted
=
False
,
):
"""Sample from local partition.
...
...
@@ -255,6 +253,8 @@ def _sample_etype_neighbors(
The sampled results are stored in three vectors that store source nodes, destination nodes
and edge IDs.
"""
assert
etype_offset
is
not
None
,
"The etype offset is not provided."
local_ids
=
partition_book
.
nid2localnid
(
seed_nodes
,
partition_book
.
partid
)
local_ids
=
F
.
astype
(
local_ids
,
local_g
.
idtype
)
...
...
@@ -278,6 +278,43 @@ def _sample_etype_neighbors(
return
LocalSampledGraph
(
global_src
,
global_dst
,
global_eids
)
def _sample_etype_neighbors(use_graphbolt, *args, **kwargs):
    """Dispatch per-edge-type neighbor sampling to the selected backend.

    Parameters
    ----------
    use_graphbolt : bool
        If true, sampling is delegated to ``_sample_neighbors_graphbolt``;
        otherwise it is delegated to ``_sample_etype_neighbors_dgl``.
    args : list
        Positional arguments forwarded unchanged to the chosen sampler.
    kwargs : dict
        Keyword arguments forwarded to the chosen sampler. On the GraphBolt
        path ``etype_offset`` and ``etype_sorted`` are dropped first.

    Returns
    -------
    object
        Whatever the chosen sampler returns; this wrapper does not
        post-process the result.
    """
    if not use_graphbolt:
        return _sample_etype_neighbors_dgl(*args, **kwargs)

    # GraphBolt does not require `etype_offset` and `etype_sorted`.
    forwarded = {
        key: value
        for key, value in kwargs.items()
        if key not in ("etype_offset", "etype_sorted")
    }
    return _sample_neighbors_graphbolt(*args, **forwarded)
def
_find_edges
(
local_g
,
partition_book
,
seed_edges
):
"""Given an edge ID array, return the source
and destination node ID array ``s`` and ``d`` in the local partition.
...
...
@@ -426,6 +463,7 @@ class SamplingRequestEtype(Request):
prob
=
None
,
replace
=
False
,
etype_sorted
=
True
,
use_graphbolt
=
False
,
):
self
.
seed_nodes
=
nodes
self
.
edge_dir
=
edge_dir
...
...
@@ -433,6 +471,7 @@ class SamplingRequestEtype(Request):
self
.
replace
=
replace
self
.
fan_out
=
fan_out
self
.
etype_sorted
=
etype_sorted
self
.
use_graphbolt
=
use_graphbolt
def
__setstate__
(
self
,
state
):
(
...
...
@@ -442,6 +481,7 @@ class SamplingRequestEtype(Request):
self
.
replace
,
self
.
fan_out
,
self
.
etype_sorted
,
self
.
use_graphbolt
,
)
=
state
def
__getstate__
(
self
):
...
...
@@ -452,6 +492,7 @@ class SamplingRequestEtype(Request):
self
.
replace
,
self
.
fan_out
,
self
.
etype_sorted
,
self
.
use_graphbolt
,
)
def
process_request
(
self
,
server_state
):
...
...
@@ -468,15 +509,16 @@ class SamplingRequestEtype(Request):
else
:
probs
=
None
res
=
_sample_etype_neighbors
(
self
.
use_graphbolt
,
local_g
,
partition_book
,
self
.
seed_nodes
,
etype_offset
,
self
.
fan_out
,
self
.
edge_dir
,
probs
,
self
.
replace
,
self
.
etype_sorted
,
edge_dir
=
self
.
edge_dir
,
prob
=
probs
,
replace
=
self
.
replace
,
etype_offset
=
etype_offset
,
etype_sorted
=
self
.
etype_sorted
,
)
return
SubgraphResponse
(
res
.
global_src
,
...
...
@@ -772,6 +814,7 @@ def sample_etype_neighbors(
prob
=
None
,
replace
=
False
,
etype_sorted
=
True
,
use_graphbolt
=
False
,
):
"""Sample from the neighbors of the given nodes from a distributed graph.
...
...
@@ -825,6 +868,8 @@ def sample_etype_neighbors(
neighbors are sampled. If fanout == -1, all neighbors are collected.
etype_sorted : bool, optional
Indicates whether etypes are sorted.
use_graphbolt : bool, optional
Whether to use GraphBolt for sampling.
Returns
-------
...
...
@@ -882,6 +927,7 @@ def sample_etype_neighbors(
prob
=
_prob
,
replace
=
replace
,
etype_sorted
=
etype_sorted
,
use_graphbolt
=
use_graphbolt
,
)
def
local_access
(
local_g
,
partition_book
,
local_nids
):
...
...
@@ -897,14 +943,15 @@ def sample_etype_neighbors(
for
etype
in
g
.
canonical_etypes
]
return
_sample_etype_neighbors
(
use_graphbolt
,
local_g
,
partition_book
,
local_nids
,
etype_offset
,
fanout
,
edge_dir
,
_prob
,
replace
,
edge_dir
=
edge_dir
,
prob
=
_prob
,
replace
=
replace
,
etype_offset
=
etype_offset
,
etype_sorted
=
etype_sorted
,
)
...
...
tests/distributed/test_distributed_sampling.py
View file @
6735a3ae
...
...
@@ -508,6 +508,8 @@ def start_hetero_etype_sample_client(
fanout
=
3
,
nodes
=
{
"n3"
:
[
0
,
10
,
99
,
66
,
124
,
208
]},
etype_sorted
=
False
,
use_graphbolt
=
False
,
return_eids
=
False
,
):
gpb
=
None
if
disable_shared_mem
:
...
...
@@ -515,12 +517,14 @@ def start_hetero_etype_sample_client(
tmpdir
/
"test_sampling.json"
,
rank
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
assert
"feat"
in
dist_graph
.
nodes
[
"n1"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n2"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n3"
].
data
if
dist_graph
.
local_partition
is
not
None
:
if
(
not
use_graphbolt
)
and
dist_graph
.
local_partition
is
not
None
:
# Check whether etypes are sorted in dist_graph
local_g
=
dist_graph
.
local_partition
local_nids
=
np
.
arange
(
local_g
.
num_nodes
())
...
...
@@ -533,10 +537,18 @@ def start_hetero_etype_sample_client(
if
gpb
is
None
:
gpb
=
dist_graph
.
get_partition_book
()
try
:
# Enable sanity check in distributed sampling.
os
.
environ
[
"DGL_DIST_DEBUG"
]
=
"1"
sampled_graph
=
sample_etype_neighbors
(
dist_graph
,
nodes
,
fanout
,
etype_sorted
=
etype_sorted
dist_graph
,
nodes
,
fanout
,
etype_sorted
=
etype_sorted
,
use_graphbolt
=
use_graphbolt
,
)
block
=
dgl
.
to_block
(
sampled_graph
,
nodes
)
if
sampled_graph
.
num_edges
()
>
0
:
if
not
use_graphbolt
or
return_eids
:
block
.
edata
[
dgl
.
EID
]
=
sampled_graph
.
edata
[
dgl
.
EID
]
except
Exception
as
e
:
print
(
traceback
.
format_exc
())
...
...
@@ -689,7 +701,11 @@ def check_rpc_hetero_sampling_empty_shuffle(
def
check_rpc_hetero_etype_sampling_shuffle
(
tmpdir
,
num_server
,
graph_formats
=
None
tmpdir
,
num_server
,
graph_formats
=
None
,
use_graphbolt
=
False
,
return_eids
=
False
,
):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
...
...
@@ -706,6 +722,8 @@ def check_rpc_hetero_etype_sampling_shuffle(
part_method
=
"metis"
,
return_mapping
=
True
,
graph_formats
=
graph_formats
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
pserver_list
=
[]
...
...
@@ -713,7 +731,14 @@ def check_rpc_hetero_etype_sampling_shuffle(
for
i
in
range
(
num_server
):
p
=
ctx
.
Process
(
target
=
start_server
,
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
]),
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
],
use_graphbolt
,
),
)
p
.
start
()
time
.
sleep
(
1
)
...
...
@@ -730,6 +755,8 @@ def check_rpc_hetero_etype_sampling_shuffle(
fanout
,
nodes
=
{
"n3"
:
[
0
,
10
,
99
,
66
,
124
,
208
]},
etype_sorted
=
etype_sorted
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
print
(
"Done sampling"
)
for
p
in
pserver_list
:
...
...
@@ -747,19 +774,26 @@ def check_rpc_hetero_etype_sampling_shuffle(
# These are global Ids after shuffling.
shuffled_src
=
F
.
gather_row
(
block
.
srcnodes
[
src_type
].
data
[
dgl
.
NID
],
src
)
shuffled_dst
=
F
.
gather_row
(
block
.
dstnodes
[
dst_type
].
data
[
dgl
.
NID
],
dst
)
shuffled_eid
=
block
.
edges
[
etype
].
data
[
dgl
.
EID
]
orig_src
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
src_type
],
shuffled_src
))
orig_dst
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
dst_type
],
shuffled_dst
))
orig_eid
=
F
.
asnumpy
(
F
.
gather_row
(
orig_eid_map
[
c_etype
],
shuffled_eid
))
assert
np
.
all
(
F
.
asnumpy
(
g
.
has_edges_between
(
orig_src
,
orig_dst
,
etype
=
etype
))
)
if
use_graphbolt
and
not
return_eids
:
continue
# Check the node Ids and edge Ids.
shuffled_eid
=
block
.
edges
[
etype
].
data
[
dgl
.
EID
]
orig_eid
=
F
.
asnumpy
(
F
.
gather_row
(
orig_eid_map
[
c_etype
],
shuffled_eid
))
orig_src1
,
orig_dst1
=
g
.
find_edges
(
orig_eid
,
etype
=
etype
)
assert
np
.
all
(
F
.
asnumpy
(
orig_src1
)
==
orig_src
)
assert
np
.
all
(
F
.
asnumpy
(
orig_dst1
)
==
orig_dst
)
def
check_rpc_hetero_etype_sampling_empty_shuffle
(
tmpdir
,
num_server
):
def
check_rpc_hetero_etype_sampling_empty_shuffle
(
tmpdir
,
num_server
,
use_graphbolt
=
False
,
return_eids
=
False
):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
g
=
create_random_hetero
(
dense
=
True
,
empty
=
True
)
...
...
@@ -774,6 +808,8 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
num_hops
=
num_hops
,
part_method
=
"metis"
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
pserver_list
=
[]
...
...
@@ -781,7 +817,14 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
for
i
in
range
(
num_server
):
p
=
ctx
.
Process
(
target
=
start_server
,
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
),
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
],
use_graphbolt
,
),
)
p
.
start
()
time
.
sleep
(
1
)
...
...
@@ -791,7 +834,13 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
deg
=
get_degrees
(
g
,
orig_nids
[
"n3"
],
"n3"
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
block
,
gpb
=
start_hetero_etype_sample_client
(
0
,
tmpdir
,
num_server
>
1
,
fanout
,
nodes
=
{
"n3"
:
empty_nids
}
0
,
tmpdir
,
num_server
>
1
,
fanout
,
nodes
=
{
"n3"
:
empty_nids
},
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
print
(
"Done sampling"
)
for
p
in
pserver_list
:
...
...
@@ -848,7 +897,13 @@ def start_bipartite_sample_client(
def
start_bipartite_etype_sample_client
(
rank
,
tmpdir
,
disable_shared_mem
,
fanout
=
3
,
nodes
=
{}
rank
,
tmpdir
,
disable_shared_mem
,
fanout
=
3
,
nodes
=
{},
use_graphbolt
=
False
,
return_eids
=
False
,
):
gpb
=
None
if
disable_shared_mem
:
...
...
@@ -856,11 +911,13 @@ def start_bipartite_etype_sample_client(
tmpdir
/
"test_sampling.json"
,
rank
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
assert
"feat"
in
dist_graph
.
nodes
[
"user"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"game"
].
data
if
dist_graph
.
local_partition
is
not
None
:
if
not
use_graphbolt
and
dist_graph
.
local_partition
is
not
None
:
# Check whether etypes are sorted in dist_graph
local_g
=
dist_graph
.
local_partition
local_nids
=
np
.
arange
(
local_g
.
num_nodes
())
...
...
@@ -872,9 +929,12 @@ def start_bipartite_etype_sample_client(
if
gpb
is
None
:
gpb
=
dist_graph
.
get_partition_book
()
sampled_graph
=
sample_etype_neighbors
(
dist_graph
,
nodes
,
fanout
)
sampled_graph
=
sample_etype_neighbors
(
dist_graph
,
nodes
,
fanout
,
use_graphbolt
=
use_graphbolt
)
block
=
dgl
.
to_block
(
sampled_graph
,
nodes
)
if
sampled_graph
.
num_edges
()
>
0
:
if
not
use_graphbolt
or
return_eids
:
block
.
edata
[
dgl
.
EID
]
=
sampled_graph
.
edata
[
dgl
.
EID
]
dgl
.
distributed
.
exit_client
()
return
block
,
gpb
...
...
@@ -1019,7 +1079,9 @@ def check_rpc_bipartite_sampling_shuffle(
assert
np
.
all
(
F
.
asnumpy
(
orig_dst1
)
==
orig_dst
)
def
check_rpc_bipartite_etype_sampling_empty
(
tmpdir
,
num_server
):
def
check_rpc_bipartite_etype_sampling_empty
(
tmpdir
,
num_server
,
use_graphbolt
=
False
,
return_eids
=
False
):
"""sample on bipartite via sample_etype_neighbors() which yields empty sample results"""
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
...
...
@@ -1035,6 +1097,8 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
num_hops
=
num_hops
,
part_method
=
"metis"
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
pserver_list
=
[]
...
...
@@ -1042,7 +1106,14 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
for
i
in
range
(
num_server
):
p
=
ctx
.
Process
(
target
=
start_server
,
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
),
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
],
use_graphbolt
,
),
)
p
.
start
()
time
.
sleep
(
1
)
...
...
@@ -1050,8 +1121,13 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
deg
=
get_degrees
(
g
,
orig_nids
[
"game"
],
"game"
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
block
,
gpb
=
start_bipartite_etype_sample_client
(
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"game"
:
empty_nids
,
"user"
:
[
1
]}
block
,
_
=
start_bipartite_etype_sample_client
(
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"game"
:
empty_nids
,
"user"
:
[
1
]},
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
print
(
"Done sampling"
)
...
...
@@ -1064,7 +1140,9 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
assert
len
(
block
.
etypes
)
==
len
(
g
.
etypes
)
def
check_rpc_bipartite_etype_sampling_shuffle
(
tmpdir
,
num_server
):
def
check_rpc_bipartite_etype_sampling_shuffle
(
tmpdir
,
num_server
,
use_graphbolt
=
False
,
return_eids
=
False
):
"""sample on bipartite via sample_etype_neighbors() which yields non-empty sample results"""
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
...
...
@@ -1080,6 +1158,8 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
num_hops
=
num_hops
,
part_method
=
"metis"
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
pserver_list
=
[]
...
...
@@ -1087,7 +1167,14 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
for
i
in
range
(
num_server
):
p
=
ctx
.
Process
(
target
=
start_server
,
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
),
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
],
use_graphbolt
,
),
)
p
.
start
()
time
.
sleep
(
1
)
...
...
@@ -1097,7 +1184,13 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
deg
=
get_degrees
(
g
,
orig_nid_map
[
"game"
],
"game"
)
nids
=
F
.
nonzero_1d
(
deg
>
0
)
block
,
gpb
=
start_bipartite_etype_sample_client
(
0
,
tmpdir
,
num_server
>
1
,
fanout
,
nodes
=
{
"game"
:
nids
,
"user"
:
[
0
]}
0
,
tmpdir
,
num_server
>
1
,
fanout
,
nodes
=
{
"game"
:
nids
,
"user"
:
[
0
]},
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
print
(
"Done sampling"
)
for
p
in
pserver_list
:
...
...
@@ -1110,13 +1203,18 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
# These are global Ids after shuffling.
shuffled_src
=
F
.
gather_row
(
block
.
srcnodes
[
src_type
].
data
[
dgl
.
NID
],
src
)
shuffled_dst
=
F
.
gather_row
(
block
.
dstnodes
[
dst_type
].
data
[
dgl
.
NID
],
dst
)
shuffled_eid
=
block
.
edges
[
etype
].
data
[
dgl
.
EID
]
orig_src
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
src_type
],
shuffled_src
))
orig_dst
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
dst_type
],
shuffled_dst
))
orig_eid
=
F
.
asnumpy
(
F
.
gather_row
(
orig_eid_map
[
c_etype
],
shuffled_eid
))
assert
np
.
all
(
F
.
asnumpy
(
g
.
has_edges_between
(
orig_src
,
orig_dst
,
etype
=
etype
))
)
if
use_graphbolt
and
not
return_eids
:
continue
# Check the node Ids and edge Ids.
shuffled_eid
=
block
.
edges
[
etype
].
data
[
dgl
.
EID
]
orig_eid
=
F
.
asnumpy
(
F
.
gather_row
(
orig_eid_map
[
c_etype
],
shuffled_eid
))
orig_src1
,
orig_dst1
=
g
.
find_edges
(
orig_eid
,
etype
=
etype
)
assert
np
.
all
(
F
.
asnumpy
(
orig_src1
)
==
orig_src
)
assert
np
.
all
(
F
.
asnumpy
(
orig_dst1
)
==
orig_dst
)
...
...
@@ -1173,7 +1271,7 @@ def test_rpc_hetero_sampling_empty_shuffle(
@
pytest
.
mark
.
parametrize
(
"graph_formats"
,
[
None
,
[
"csc"
],
[
"csr"
],
[
"csc"
,
"coo"
]]
)
def
test_rpc_hetero_etype_sampling_shuffle
(
num_server
,
graph_formats
):
def
test_rpc_hetero_etype_sampling_shuffle
_dgl
(
num_server
,
graph_formats
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
...
...
@@ -1183,12 +1281,33 @@ def test_rpc_hetero_etype_sampling_shuffle(num_server, graph_formats):
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_hetero_etype_sampling_empty_shuffle
(
num_server
):
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_hetero_etype_sampling_shuffle_graphbolt
(
num_server
,
return_eids
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_hetero_etype_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
True
,
return_eids
=
return_eids
,
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_hetero_etype_sampling_empty_shuffle
(
num_server
,
use_graphbolt
,
return_eids
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_hetero_etype_sampling_empty_shuffle
(
Path
(
tmpdirname
),
num_server
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
...
...
@@ -1219,19 +1338,37 @@ def test_rpc_bipartite_sampling_shuffle(num_server, use_graphbolt, return_eids):
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_bipartite_etype_sampling_empty_shuffle
(
num_server
):
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_bipartite_etype_sampling_empty_shuffle
(
num_server
,
use_graphbolt
,
return_eids
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_bipartite_etype_sampling_empty
(
Path
(
tmpdirname
),
num_server
)
check_rpc_bipartite_etype_sampling_empty
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_bipartite_etype_sampling_shuffle
(
num_server
):
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_bipartite_etype_sampling_shuffle
(
num_server
,
use_graphbolt
,
return_eids
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_bipartite_etype_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
check_rpc_bipartite_etype_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
def
check_standalone_sampling
(
tmpdir
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment