Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
f3af2a9f
Unverified
Commit
f3af2a9f
authored
Feb 05, 2024
by
Rhett Ying
Committed by
GitHub
Feb 05, 2024
Browse files
[DistGB] return eids together with etype_ids in sampling (#7084)
parent
346197c4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
171 additions
and
77 deletions
+171
-77
python/dgl/distributed/graph_services.py
python/dgl/distributed/graph_services.py
+107
-51
tests/distributed/test_distributed_sampling.py
tests/distributed/test_distributed_sampling.py
+64
-26
No files found.
python/dgl/distributed/graph_services.py
View file @
f3af2a9f
...
@@ -6,7 +6,7 @@ import numpy as np
...
@@ -6,7 +6,7 @@ import numpy as np
import
torch
import
torch
from
..
import
backend
as
F
,
graphbolt
as
gb
from
..
import
backend
as
F
,
graphbolt
as
gb
from
..base
import
EID
,
NID
from
..base
import
EID
,
ETYPE
,
NID
from
..convert
import
graph
,
heterograph
from
..convert
import
graph
,
heterograph
from
..sampling
import
(
from
..sampling
import
(
sample_etype_neighbors
as
local_sample_etype_neighbors
,
sample_etype_neighbors
as
local_sample_etype_neighbors
,
...
@@ -40,16 +40,29 @@ ETYPE_SAMPLING_SERVICE_ID = 6662
...
@@ -40,16 +40,29 @@ ETYPE_SAMPLING_SERVICE_ID = 6662
class
SubgraphResponse
(
Response
):
class
SubgraphResponse
(
Response
):
"""The response for sampling and in_subgraph"""
"""The response for sampling and in_subgraph"""
def
__init__
(
self
,
global_src
,
global_dst
,
global_eids
):
def
__init__
(
self
,
global_src
,
global_dst
,
*
,
global_eids
=
None
,
etype_ids
=
None
):
self
.
global_src
=
global_src
self
.
global_src
=
global_src
self
.
global_dst
=
global_dst
self
.
global_dst
=
global_dst
self
.
global_eids
=
global_eids
self
.
global_eids
=
global_eids
self
.
etype_ids
=
etype_ids
def
__setstate__
(
self
,
state
):
def
__setstate__
(
self
,
state
):
self
.
global_src
,
self
.
global_dst
,
self
.
global_eids
=
state
(
self
.
global_src
,
self
.
global_dst
,
self
.
global_eids
,
self
.
etype_ids
,
)
=
state
def
__getstate__
(
self
):
def
__getstate__
(
self
):
return
self
.
global_src
,
self
.
global_dst
,
self
.
global_eids
return
(
self
.
global_src
,
self
.
global_dst
,
self
.
global_eids
,
self
.
etype_ids
,
)
class
FindEdgeResponse
(
Response
):
class
FindEdgeResponse
(
Response
):
...
@@ -68,7 +81,7 @@ class FindEdgeResponse(Response):
...
@@ -68,7 +81,7 @@ class FindEdgeResponse(Response):
def
_sample_neighbors_graphbolt
(
def
_sample_neighbors_graphbolt
(
g
,
gpb
,
nodes
,
fanout
,
prob
=
None
,
replace
=
False
g
,
gpb
,
nodes
,
fanout
,
edge_dir
=
"in"
,
prob
=
None
,
replace
=
False
):
):
"""Sample from local partition via graphbolt.
"""Sample from local partition via graphbolt.
...
@@ -77,8 +90,6 @@ def _sample_neighbors_graphbolt(
...
@@ -77,8 +90,6 @@ def _sample_neighbors_graphbolt(
space again. The sampled results are stored in three vectors that store
space again. The sampled results are stored in three vectors that store
source nodes, destination nodes, etype IDs and edge IDs.
source nodes, destination nodes, etype IDs and edge IDs.
[Rui][TODO] edge IDs are not returned as not supported yet.
Parameters
Parameters
----------
----------
g : FusedCSCSamplingGraph
g : FusedCSCSamplingGraph
...
@@ -89,6 +100,8 @@ def _sample_neighbors_graphbolt(
...
@@ -89,6 +100,8 @@ def _sample_neighbors_graphbolt(
The nodes to sample neighbors from.
The nodes to sample neighbors from.
fanout : tensor or int
fanout : tensor or int
The number of edges to be sampled for each node.
The number of edges to be sampled for each node.
edge_dir : str, optional
Determines whether to sample inbound or outbound edges.
prob : tensor, optional
prob : tensor, optional
The probability associated with each neighboring edge of a node.
The probability associated with each neighboring edge of a node.
replace : bool, optional
replace : bool, optional
...
@@ -100,11 +113,15 @@ def _sample_neighbors_graphbolt(
...
@@ -100,11 +113,15 @@ def _sample_neighbors_graphbolt(
The source node ID array.
The source node ID array.
tensor
tensor
The destination node ID array.
The destination node ID array.
tensor
The edge type ID array.
tensor
tensor
The edge ID array.
The edge ID array.
tensor
The edge type ID array.
"""
"""
assert
(
edge_dir
==
"in"
),
f
"GraphBolt only supports inbound edge sampling but got
{
edge_dir
}
."
# 1. Map global node IDs to local node IDs.
# 1. Map global node IDs to local node IDs.
nodes
=
gpb
.
nid2localnid
(
nodes
,
gpb
.
partid
)
nodes
=
gpb
.
nid2localnid
(
nodes
,
gpb
.
partid
)
...
@@ -139,11 +156,20 @@ def _sample_neighbors_graphbolt(
...
@@ -139,11 +156,20 @@ def _sample_neighbors_graphbolt(
global_src
=
global_nid_mapping
[
local_src
]
global_src
=
global_nid_mapping
[
local_src
]
global_dst
=
global_nid_mapping
[
local_dst
]
global_dst
=
global_nid_mapping
[
local_dst
]
return
global_src
,
global_dst
,
subgraph
.
type_per_edge
# [Rui][TODO] edge IDs are not supported yet.
return
LocalSampledGraph
(
global_src
,
global_dst
,
None
,
subgraph
.
type_per_edge
)
def
_sample_neighbors
(
def
_sample_neighbors_dgl
(
local_g
,
partition_book
,
seed_nodes
,
fan_out
,
edge_dir
,
prob
,
replace
local_g
,
partition_book
,
seed_nodes
,
fan_out
,
edge_dir
=
"in"
,
prob
=
None
,
replace
=
False
,
):
):
"""Sample from local partition.
"""Sample from local partition.
...
@@ -170,7 +196,38 @@ def _sample_neighbors(
...
@@ -170,7 +196,38 @@ def _sample_neighbors(
global_nid_mapping
,
src
global_nid_mapping
,
src
),
F
.
gather_row
(
global_nid_mapping
,
dst
)
),
F
.
gather_row
(
global_nid_mapping
,
dst
)
global_eids
=
F
.
gather_row
(
local_g
.
edata
[
EID
],
sampled_graph
.
edata
[
EID
])
global_eids
=
F
.
gather_row
(
local_g
.
edata
[
EID
],
sampled_graph
.
edata
[
EID
])
return
global_src
,
global_dst
,
global_eids
return
LocalSampledGraph
(
global_src
,
global_dst
,
global_eids
)
def
_sample_neighbors
(
use_graphbolt
,
*
args
,
**
kwargs
):
"""Wrapper for sampling neighbors.
The actual sampling function depends on whether to use GraphBolt.
Parameters
----------
use_graphbolt : bool
Whether to use GraphBolt for sampling.
args : list
The arguments for the sampling function.
kwargs : dict
The keyword arguments for the sampling function.
Returns
-------
tensor
The source node ID array.
tensor
The destination node ID array.
tensor
The edge ID array.
tensor
The edge type ID array.
"""
func
=
(
_sample_neighbors_graphbolt
if
use_graphbolt
else
_sample_neighbors_dgl
)
return
func
(
*
args
,
**
kwargs
)
def
_sample_etype_neighbors
(
def
_sample_etype_neighbors
(
...
@@ -211,7 +268,7 @@ def _sample_etype_neighbors(
...
@@ -211,7 +268,7 @@ def _sample_etype_neighbors(
global_nid_mapping
,
src
global_nid_mapping
,
src
),
F
.
gather_row
(
global_nid_mapping
,
dst
)
),
F
.
gather_row
(
global_nid_mapping
,
dst
)
global_eids
=
F
.
gather_row
(
local_g
.
edata
[
EID
],
sampled_graph
.
edata
[
EID
])
global_eids
=
F
.
gather_row
(
local_g
.
edata
[
EID
],
sampled_graph
.
edata
[
EID
])
return
global_src
,
global_dst
,
global_eids
return
LocalSampledGraph
(
global_src
,
global_dst
,
global_eids
)
def
_find_edges
(
local_g
,
partition_book
,
seed_edges
):
def
_find_edges
(
local_g
,
partition_book
,
seed_edges
):
...
@@ -257,7 +314,7 @@ def _in_subgraph(local_g, partition_book, seed_nodes):
...
@@ -257,7 +314,7 @@ def _in_subgraph(local_g, partition_book, seed_nodes):
src
,
dst
=
sampled_graph
.
edges
()
src
,
dst
=
sampled_graph
.
edges
()
global_src
,
global_dst
=
global_nid_mapping
[
src
],
global_nid_mapping
[
dst
]
global_src
,
global_dst
=
global_nid_mapping
[
src
],
global_nid_mapping
[
dst
]
global_eids
=
F
.
gather_row
(
local_g
.
edata
[
EID
],
sampled_graph
.
edata
[
EID
])
global_eids
=
F
.
gather_row
(
local_g
.
edata
[
EID
],
sampled_graph
.
edata
[
EID
])
return
global_src
,
global_dst
,
global_eids
return
LocalSampledGraph
(
global_src
,
global_dst
,
global_eids
)
# --- NOTE 1 ---
# --- NOTE 1 ---
...
@@ -333,26 +390,22 @@ class SamplingRequest(Request):
...
@@ -333,26 +390,22 @@ class SamplingRequest(Request):
prob
=
[
kv_store
.
data_store
[
self
.
prob
]]
prob
=
[
kv_store
.
data_store
[
self
.
prob
]]
else
:
else
:
prob
=
None
prob
=
None
if
self
.
use_graphbolt
:
res
=
_sample_neighbors
(
global_src
,
global_dst
,
etype_ids
=
_sample_neighbors_graphbolt
(
self
.
use_graphbolt
,
local_g
,
partition_book
,
self
.
seed_nodes
,
self
.
fan_out
,
prob
,
self
.
replace
,
)
return
SubgraphResponse
(
global_src
,
global_dst
,
etype_ids
)
global_src
,
global_dst
,
global_eids
=
_sample_neighbors
(
local_g
,
local_g
,
partition_book
,
partition_book
,
self
.
seed_nodes
,
self
.
seed_nodes
,
self
.
fan_out
,
self
.
fan_out
,
self
.
edge_dir
,
edge_dir
=
self
.
edge_dir
,
prob
,
prob
=
prob
,
self
.
replace
,
replace
=
self
.
replace
,
)
return
SubgraphResponse
(
res
.
global_src
,
res
.
global_dst
,
global_eids
=
res
.
global_eids
,
etype_ids
=
res
.
etype_ids
,
)
)
return
SubgraphResponse
(
global_src
,
global_dst
,
global_eids
)
class
SamplingRequestEtype
(
Request
):
class
SamplingRequestEtype
(
Request
):
...
@@ -407,7 +460,7 @@ class SamplingRequestEtype(Request):
...
@@ -407,7 +460,7 @@ class SamplingRequestEtype(Request):
]
]
else
:
else
:
probs
=
None
probs
=
None
global_src
,
global_dst
,
global_eid
s
=
_sample_etype_neighbors
(
re
s
=
_sample_etype_neighbors
(
local_g
,
local_g
,
partition_book
,
partition_book
,
self
.
seed_nodes
,
self
.
seed_nodes
,
...
@@ -418,7 +471,12 @@ class SamplingRequestEtype(Request):
...
@@ -418,7 +471,12 @@ class SamplingRequestEtype(Request):
self
.
replace
,
self
.
replace
,
self
.
etype_sorted
,
self
.
etype_sorted
,
)
)
return
SubgraphResponse
(
global_src
,
global_dst
,
global_eids
)
return
SubgraphResponse
(
res
.
global_src
,
res
.
global_dst
,
global_eids
=
res
.
global_eids
,
etype_ids
=
res
.
etype_ids
,
)
class
EdgesRequest
(
Request
):
class
EdgesRequest
(
Request
):
...
@@ -532,7 +590,7 @@ class InSubgraphRequest(Request):
...
@@ -532,7 +590,7 @@ class InSubgraphRequest(Request):
global_src
,
global_dst
,
global_eids
=
_in_subgraph
(
global_src
,
global_dst
,
global_eids
=
_in_subgraph
(
local_g
,
partition_book
,
self
.
seed_nodes
local_g
,
partition_book
,
self
.
seed_nodes
)
)
return
SubgraphResponse
(
global_src
,
global_dst
,
global_eids
)
return
SubgraphResponse
(
global_src
,
global_dst
,
global_eids
=
global_eids
)
def
merge_graphs
(
res_list
,
num_nodes
):
def
merge_graphs
(
res_list
,
num_nodes
):
...
@@ -541,25 +599,33 @@ def merge_graphs(res_list, num_nodes):
...
@@ -541,25 +599,33 @@ def merge_graphs(res_list, num_nodes):
srcs
=
[]
srcs
=
[]
dsts
=
[]
dsts
=
[]
eids
=
[]
eids
=
[]
etype_ids
=
[]
for
res
in
res_list
:
for
res
in
res_list
:
srcs
.
append
(
res
.
global_src
)
srcs
.
append
(
res
.
global_src
)
dsts
.
append
(
res
.
global_dst
)
dsts
.
append
(
res
.
global_dst
)
eids
.
append
(
res
.
global_eids
)
eids
.
append
(
res
.
global_eids
)
etype_ids
.
append
(
res
.
etype_ids
)
src_tensor
=
F
.
cat
(
srcs
,
0
)
src_tensor
=
F
.
cat
(
srcs
,
0
)
dst_tensor
=
F
.
cat
(
dsts
,
0
)
dst_tensor
=
F
.
cat
(
dsts
,
0
)
eid_tensor
=
None
if
eids
[
0
]
is
None
else
F
.
cat
(
eids
,
0
)
eid_tensor
=
None
if
eids
[
0
]
is
None
else
F
.
cat
(
eids
,
0
)
etype_id_tensor
=
None
if
etype_ids
[
0
]
is
None
else
F
.
cat
(
etype_ids
,
0
)
else
:
else
:
src_tensor
=
res_list
[
0
].
global_src
src_tensor
=
res_list
[
0
].
global_src
dst_tensor
=
res_list
[
0
].
global_dst
dst_tensor
=
res_list
[
0
].
global_dst
eid_tensor
=
res_list
[
0
].
global_eids
eid_tensor
=
res_list
[
0
].
global_eids
etype_id_tensor
=
res_list
[
0
].
etype_ids
g
=
graph
((
src_tensor
,
dst_tensor
),
num_nodes
=
num_nodes
)
g
=
graph
((
src_tensor
,
dst_tensor
),
num_nodes
=
num_nodes
)
if
eid_tensor
is
not
None
:
if
eid_tensor
is
not
None
:
g
.
edata
[
EID
]
=
eid_tensor
g
.
edata
[
EID
]
=
eid_tensor
if
etype_id_tensor
is
not
None
:
g
.
edata
[
ETYPE
]
=
etype_id_tensor
return
g
return
g
LocalSampledGraph
=
namedtuple
(
LocalSampledGraph
=
namedtuple
(
# pylint: disable=unexpected-keyword-arg
"LocalSampledGraph"
,
"global_src global_dst global_eids"
"LocalSampledGraph"
,
"global_src global_dst global_eids etype_ids"
,
defaults
=
(
None
,
None
,
None
,
None
),
)
)
...
@@ -615,10 +681,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
...
@@ -615,10 +681,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
# sample neighbors for the nodes in the local partition.
# sample neighbors for the nodes in the local partition.
res_list
=
[]
res_list
=
[]
if
local_nids
is
not
None
:
if
local_nids
is
not
None
:
src
,
dst
,
eids
=
local_access
(
res
=
local_access
(
g
.
local_partition
,
partition_book
,
local_nids
)
g
.
local_partition
,
partition_book
,
local_nids
res_list
.
append
(
res
)
)
res_list
.
append
(
LocalSampledGraph
(
src
,
dst
,
eids
))
# receive responses from remote machines.
# receive responses from remote machines.
if
msgseq2pos
is
not
None
:
if
msgseq2pos
is
not
None
:
...
@@ -916,23 +980,15 @@ def sample_neighbors(
...
@@ -916,23 +980,15 @@ def sample_neighbors(
def
local_access
(
local_g
,
partition_book
,
local_nids
):
def
local_access
(
local_g
,
partition_book
,
local_nids
):
# See NOTE 1
# See NOTE 1
_prob
=
[
g
.
edata
[
prob
].
local_partition
]
if
prob
is
not
None
else
None
_prob
=
[
g
.
edata
[
prob
].
local_partition
]
if
prob
is
not
None
else
None
if
use_graphbolt
:
return
_sample_neighbors_graphbolt
(
local_g
,
partition_book
,
local_nids
,
fanout
,
prob
=
_prob
,
replace
=
replace
,
)
return
_sample_neighbors
(
return
_sample_neighbors
(
use_graphbolt
,
local_g
,
local_g
,
partition_book
,
partition_book
,
local_nids
,
local_nids
,
fanout
,
fanout
,
edge_dir
,
edge_dir
=
edge_dir
,
_prob
,
prob
=
_prob
,
replace
,
replace
=
replace
,
)
)
frontier
=
_distributed_access
(
g
,
nodes
,
issue_remote_req
,
local_access
)
frontier
=
_distributed_access
(
g
,
nodes
,
issue_remote_req
,
local_access
)
...
...
tests/distributed/test_distributed_sampling.py
View file @
f3af2a9f
import
multiprocessing
as
mp
import
multiprocessing
as
mp
import
os
import
os
import
random
import
random
import
sys
import
tempfile
import
time
import
time
import
traceback
import
traceback
import
unittest
import
unittest
...
@@ -1013,47 +1013,85 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
...
@@ -1013,47 +1013,85 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
assert
np
.
all
(
F
.
asnumpy
(
orig_dst1
)
==
orig_dst
)
assert
np
.
all
(
F
.
asnumpy
(
orig_dst1
)
==
orig_dst
)
# Wait non shared memory graph store
@
unittest
.
skipIf
(
os
.
name
==
"nt"
,
reason
=
"Do not support windows yet"
)
@
unittest
.
skipIf
(
dgl
.
backend
.
backend_name
==
"tensorflow"
,
reason
=
"Not support tensorflow for now"
,
)
@
unittest
.
skipIf
(
dgl
.
backend
.
backend_name
==
"mxnet"
,
reason
=
"Turn off Mxnet support"
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_sampling_shuffle
(
num_server
):
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
def
test_rpc_sampling_shuffle
(
num_server
,
use_graphbolt
):
reset_envs
()
reset_envs
()
import
tempfile
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_sampling_shuffle
(
check_rpc_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
True
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
use_graphbolt
)
)
check_rpc_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
# [TODO][Rhett] Tests for multiple groups may fail sometimes and
# root cause is unknown. Let's disable them for now.
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
# check_rpc_sampling_shuffle(Path(tmpdirname), num_server, num_groups=2)
def
test_rpc_hetero_sampling_shuffle
(
num_server
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_hetero_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
check_rpc_hetero_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_hetero_sampling_empty_shuffle
(
num_server
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_hetero_sampling_empty_shuffle
(
Path
(
tmpdirname
),
num_server
)
check_rpc_hetero_sampling_empty_shuffle
(
Path
(
tmpdirname
),
num_server
)
check_rpc_hetero_etype_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
check_rpc_hetero_etype_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
,
[
"csc"
]
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
)
@
pytest
.
mark
.
parametrize
(
check_rpc_hetero_etype_sampling_shuffle
(
"graph_formats"
,
[
None
,
[
"csc"
],
[
"csr"
],
[
"csc"
,
"coo"
]]
Path
(
tmpdirname
),
num_server
,
[
"csr"
]
)
)
def
test_rpc_hetero_etype_sampling_shuffle
(
num_server
,
graph_formats
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_hetero_etype_sampling_shuffle
(
check_rpc_hetero_etype_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
,
[
"csc"
,
"coo"
]
Path
(
tmpdirname
),
num_server
,
graph_formats
=
graph_formats
)
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_hetero_etype_sampling_empty_shuffle
(
num_server
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_hetero_etype_sampling_empty_shuffle
(
check_rpc_hetero_etype_sampling_empty_shuffle
(
Path
(
tmpdirname
),
num_server
Path
(
tmpdirname
),
num_server
)
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_bipartite_sampling_empty_shuffle
(
num_server
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_bipartite_sampling_empty
(
Path
(
tmpdirname
),
num_server
)
check_rpc_bipartite_sampling_empty
(
Path
(
tmpdirname
),
num_server
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_bipartite_sampling_shuffle
(
num_server
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_bipartite_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
check_rpc_bipartite_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_bipartite_etype_sampling_empty_shuffle
(
num_server
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_bipartite_etype_sampling_empty
(
Path
(
tmpdirname
),
num_server
)
check_rpc_bipartite_etype_sampling_empty
(
Path
(
tmpdirname
),
num_server
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_bipartite_etype_sampling_shuffle
(
num_server
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_bipartite_etype_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
check_rpc_bipartite_etype_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment