Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
ee8b7b39
Unverified
Commit
ee8b7b39
authored
Feb 06, 2024
by
Rhett Ying
Committed by
GitHub
Feb 06, 2024
Browse files
[DistGB] enable GB sampling on heterograph (#7087)
parent
a2e1c796
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
196 additions
and
40 deletions
+196
-40
python/dgl/distributed/graph_services.py
python/dgl/distributed/graph_services.py
+30
-6
tests/distributed/test_distributed_sampling.py
tests/distributed/test_distributed_sampling.py
+166
-34
No files found.
python/dgl/distributed/graph_services.py
View file @
ee8b7b39
"""A set of graph services of getting subgraphs from DistGraph"""
"""A set of graph services of getting subgraphs from DistGraph"""
import
os
from
collections
import
namedtuple
from
collections
import
namedtuple
import
numpy
as
np
import
numpy
as
np
...
@@ -708,24 +709,47 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb):
...
@@ -708,24 +709,47 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb):
idtype
=
g
.
idtype
,
idtype
=
g
.
idtype
,
)
)
etype_ids
,
frontier
.
edata
[
EID
]
=
gpb
.
map_to_per_etype
(
frontier
.
edata
[
EID
])
# For DGL partitions, the global edge IDs are always stored in the edata.
src
,
dst
=
frontier
.
edges
()
# For GraphBolt partitions, the edge type IDs are always stored in the
# edata. As for the edge IDs, they are stored in the edata if the graph is
# partitioned with `store_eids=True`. Otherwise, the edge IDs are not
# stored.
etype_ids
,
type_wise_eids
=
(
gpb
.
map_to_per_etype
(
frontier
.
edata
[
EID
])
if
EID
in
frontier
.
edata
else
(
frontier
.
edata
[
ETYPE
],
None
)
)
etype_ids
,
idx
=
F
.
sort_1d
(
etype_ids
)
etype_ids
,
idx
=
F
.
sort_1d
(
etype_ids
)
if
type_wise_eids
is
not
None
:
type_wise_eids
=
F
.
gather_row
(
type_wise_eids
,
idx
)
# Sort the edges by their edge types.
src
,
dst
=
frontier
.
edges
()
src
,
dst
=
F
.
gather_row
(
src
,
idx
),
F
.
gather_row
(
dst
,
idx
)
src
,
dst
=
F
.
gather_row
(
src
,
idx
),
F
.
gather_row
(
dst
,
idx
)
eid
=
F
.
gather_row
(
frontier
.
edata
[
EID
],
idx
)
src_ntype_ids
,
src
=
gpb
.
map_to_per_ntype
(
src
)
_
,
src
=
gpb
.
map_to_per_ntype
(
src
)
dst_ntype_ids
,
dst
=
gpb
.
map_to_per_ntype
(
dst
)
_
,
dst
=
gpb
.
map_to_per_ntype
(
dst
)
data_dict
=
dict
()
data_dict
=
dict
()
edge_ids
=
{}
edge_ids
=
{}
for
etid
,
etype
in
enumerate
(
g
.
canonical_etypes
):
for
etid
,
etype
in
enumerate
(
g
.
canonical_etypes
):
src_ntype
,
_
,
dst_ntype
=
etype
src_ntype_id
=
g
.
get_ntype_id
(
src_ntype
)
dst_ntype_id
=
g
.
get_ntype_id
(
dst_ntype
)
type_idx
=
etype_ids
==
etid
type_idx
=
etype_ids
==
etid
if
F
.
sum
(
type_idx
,
0
)
>
0
:
if
F
.
sum
(
type_idx
,
0
)
>
0
:
data_dict
[
etype
]
=
(
data_dict
[
etype
]
=
(
F
.
boolean_mask
(
src
,
type_idx
),
F
.
boolean_mask
(
src
,
type_idx
),
F
.
boolean_mask
(
dst
,
type_idx
),
F
.
boolean_mask
(
dst
,
type_idx
),
)
)
edge_ids
[
etype
]
=
F
.
boolean_mask
(
eid
,
type_idx
)
if
"DGL_DIST_DEBUG"
in
os
.
environ
:
assert
torch
.
all
(
src_ntype_id
==
src_ntype_ids
[
type_idx
]
),
"source ntype is is not expected."
assert
torch
.
all
(
dst_ntype_id
==
dst_ntype_ids
[
type_idx
]
),
"destination ntype is is not expected."
if
type_wise_eids
is
not
None
:
edge_ids
[
etype
]
=
F
.
boolean_mask
(
type_wise_eids
,
type_idx
)
hg
=
heterograph
(
hg
=
heterograph
(
data_dict
,
data_dict
,
{
ntype
:
g
.
num_nodes
(
ntype
)
for
ntype
in
g
.
ntypes
},
{
ntype
:
g
.
num_nodes
(
ntype
)
for
ntype
in
g
.
ntypes
},
...
...
tests/distributed/test_distributed_sampling.py
View file @
ee8b7b39
...
@@ -91,6 +91,9 @@ def start_sample_client_shuffle(
...
@@ -91,6 +91,9 @@ def start_sample_client_shuffle(
dist_graph
,
[
0
,
10
,
99
,
66
,
1024
,
2008
],
3
,
use_graphbolt
=
use_graphbolt
dist_graph
,
[
0
,
10
,
99
,
66
,
1024
,
2008
],
3
,
use_graphbolt
=
use_graphbolt
)
)
assert
(
dgl
.
ETYPE
not
in
sampled_graph
.
edata
),
"Etype should not be in homogeneous sampled graph."
src
,
dst
=
sampled_graph
.
edges
()
src
,
dst
=
sampled_graph
.
edges
()
src
=
orig_nid
[
src
]
src
=
orig_nid
[
src
]
dst
=
orig_nid
[
dst
]
dst
=
orig_nid
[
dst
]
...
@@ -460,23 +463,37 @@ def check_rpc_sampling_shuffle(
...
@@ -460,23 +463,37 @@ def check_rpc_sampling_shuffle(
assert
p
.
exitcode
==
0
assert
p
.
exitcode
==
0
def
start_hetero_sample_client
(
rank
,
tmpdir
,
disable_shared_mem
,
nodes
):
def
start_hetero_sample_client
(
rank
,
tmpdir
,
disable_shared_mem
,
nodes
,
use_graphbolt
=
False
,
return_eids
=
False
,
):
gpb
=
None
gpb
=
None
if
disable_shared_mem
:
if
disable_shared_mem
:
_
,
_
,
_
,
gpb
,
_
,
_
,
_
=
load_partition
(
_
,
_
,
_
,
gpb
,
_
,
_
,
_
=
load_partition
(
tmpdir
/
"test_sampling.json"
,
rank
tmpdir
/
"test_sampling.json"
,
rank
)
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
assert
"feat"
in
dist_graph
.
nodes
[
"n1"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"n1"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n2"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n2"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n3"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n3"
].
data
if
gpb
is
None
:
if
gpb
is
None
:
gpb
=
dist_graph
.
get_partition_book
()
gpb
=
dist_graph
.
get_partition_book
()
try
:
try
:
sampled_graph
=
sample_neighbors
(
dist_graph
,
nodes
,
3
)
# Enable santity check in distributed sampling.
os
.
environ
[
"DGL_DIST_DEBUG"
]
=
"1"
sampled_graph
=
sample_neighbors
(
dist_graph
,
nodes
,
3
,
use_graphbolt
=
use_graphbolt
)
block
=
dgl
.
to_block
(
sampled_graph
,
nodes
)
block
=
dgl
.
to_block
(
sampled_graph
,
nodes
)
block
.
edata
[
dgl
.
EID
]
=
sampled_graph
.
edata
[
dgl
.
EID
]
if
not
use_graphbolt
or
return_eids
:
block
.
edata
[
dgl
.
EID
]
=
sampled_graph
.
edata
[
dgl
.
EID
]
except
Exception
as
e
:
except
Exception
as
e
:
print
(
traceback
.
format_exc
())
print
(
traceback
.
format_exc
())
block
=
None
block
=
None
...
@@ -528,7 +545,9 @@ def start_hetero_etype_sample_client(
...
@@ -528,7 +545,9 @@ def start_hetero_etype_sample_client(
return
block
,
gpb
return
block
,
gpb
def
check_rpc_hetero_sampling_shuffle
(
tmpdir
,
num_server
):
def
check_rpc_hetero_sampling_shuffle
(
tmpdir
,
num_server
,
use_graphbolt
=
False
,
return_eids
=
False
):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
g
=
create_random_hetero
()
g
=
create_random_hetero
()
...
@@ -543,6 +562,8 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
...
@@ -543,6 +562,8 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
num_hops
=
num_hops
,
num_hops
=
num_hops
,
part_method
=
"metis"
,
part_method
=
"metis"
,
return_mapping
=
True
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
)
pserver_list
=
[]
pserver_list
=
[]
...
@@ -550,16 +571,27 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
...
@@ -550,16 +571,27 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
for
i
in
range
(
num_server
):
for
i
in
range
(
num_server
):
p
=
ctx
.
Process
(
p
=
ctx
.
Process
(
target
=
start_server
,
target
=
start_server
,
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
),
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
],
use_graphbolt
,
),
)
)
p
.
start
()
p
.
start
()
time
.
sleep
(
1
)
time
.
sleep
(
1
)
pserver_list
.
append
(
p
)
pserver_list
.
append
(
p
)
block
,
gpb
=
start_hetero_sample_client
(
block
,
gpb
=
start_hetero_sample_client
(
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"n3"
:
[
0
,
10
,
99
,
66
,
124
,
208
]}
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"n3"
:
[
0
,
10
,
99
,
66
,
124
,
208
]},
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
)
print
(
"Done sampling"
)
for
p
in
pserver_list
:
for
p
in
pserver_list
:
p
.
join
()
p
.
join
()
assert
p
.
exitcode
==
0
assert
p
.
exitcode
==
0
...
@@ -570,10 +602,17 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
...
@@ -570,10 +602,17 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
# These are global Ids after shuffling.
# These are global Ids after shuffling.
shuffled_src
=
F
.
gather_row
(
block
.
srcnodes
[
src_type
].
data
[
dgl
.
NID
],
src
)
shuffled_src
=
F
.
gather_row
(
block
.
srcnodes
[
src_type
].
data
[
dgl
.
NID
],
src
)
shuffled_dst
=
F
.
gather_row
(
block
.
dstnodes
[
dst_type
].
data
[
dgl
.
NID
],
dst
)
shuffled_dst
=
F
.
gather_row
(
block
.
dstnodes
[
dst_type
].
data
[
dgl
.
NID
],
dst
)
shuffled_eid
=
block
.
edges
[
etype
].
data
[
dgl
.
EID
]
orig_src
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
src_type
],
shuffled_src
))
orig_src
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
src_type
],
shuffled_src
))
orig_dst
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
dst_type
],
shuffled_dst
))
orig_dst
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
dst_type
],
shuffled_dst
))
assert
np
.
all
(
F
.
asnumpy
(
g
.
has_edges_between
(
orig_src
,
orig_dst
,
etype
=
etype
))
)
if
use_graphbolt
and
not
return_eids
:
continue
shuffled_eid
=
block
.
edges
[
etype
].
data
[
dgl
.
EID
]
orig_eid
=
F
.
asnumpy
(
F
.
gather_row
(
orig_eid_map
[
c_etype
],
shuffled_eid
))
orig_eid
=
F
.
asnumpy
(
F
.
gather_row
(
orig_eid_map
[
c_etype
],
shuffled_eid
))
# Check the node Ids and edge Ids.
# Check the node Ids and edge Ids.
...
@@ -592,7 +631,9 @@ def get_degrees(g, nids, ntype):
...
@@ -592,7 +631,9 @@ def get_degrees(g, nids, ntype):
return
deg
return
deg
def
check_rpc_hetero_sampling_empty_shuffle
(
tmpdir
,
num_server
):
def
check_rpc_hetero_sampling_empty_shuffle
(
tmpdir
,
num_server
,
use_graphbolt
=
False
,
return_eids
=
False
):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
g
=
create_random_hetero
(
empty
=
True
)
g
=
create_random_hetero
(
empty
=
True
)
...
@@ -607,6 +648,8 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
...
@@ -607,6 +648,8 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
num_hops
=
num_hops
,
num_hops
=
num_hops
,
part_method
=
"metis"
,
part_method
=
"metis"
,
return_mapping
=
True
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
)
pserver_list
=
[]
pserver_list
=
[]
...
@@ -614,7 +657,14 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
...
@@ -614,7 +657,14 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
for
i
in
range
(
num_server
):
for
i
in
range
(
num_server
):
p
=
ctx
.
Process
(
p
=
ctx
.
Process
(
target
=
start_server
,
target
=
start_server
,
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
),
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
],
use_graphbolt
,
),
)
)
p
.
start
()
p
.
start
()
time
.
sleep
(
1
)
time
.
sleep
(
1
)
...
@@ -623,9 +673,13 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
...
@@ -623,9 +673,13 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
deg
=
get_degrees
(
g
,
orig_nids
[
"n3"
],
"n3"
)
deg
=
get_degrees
(
g
,
orig_nids
[
"n3"
],
"n3"
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
block
,
gpb
=
start_hetero_sample_client
(
block
,
gpb
=
start_hetero_sample_client
(
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"n3"
:
empty_nids
}
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"n3"
:
empty_nids
},
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
)
print
(
"Done sampling"
)
for
p
in
pserver_list
:
for
p
in
pserver_list
:
p
.
join
()
p
.
join
()
assert
p
.
exitcode
==
0
assert
p
.
exitcode
==
0
...
@@ -759,22 +813,36 @@ def create_random_bipartite():
...
@@ -759,22 +813,36 @@ def create_random_bipartite():
return
g
return
g
def
start_bipartite_sample_client
(
rank
,
tmpdir
,
disable_shared_mem
,
nodes
):
def
start_bipartite_sample_client
(
rank
,
tmpdir
,
disable_shared_mem
,
nodes
,
use_graphbolt
=
False
,
return_eids
=
False
,
):
gpb
=
None
gpb
=
None
if
disable_shared_mem
:
if
disable_shared_mem
:
_
,
_
,
_
,
gpb
,
_
,
_
,
_
=
load_partition
(
_
,
_
,
_
,
gpb
,
_
,
_
,
_
=
load_partition
(
tmpdir
/
"test_sampling.json"
,
rank
tmpdir
/
"test_sampling.json"
,
rank
)
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
assert
"feat"
in
dist_graph
.
nodes
[
"user"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"user"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"game"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"game"
].
data
if
gpb
is
None
:
if
gpb
is
None
:
gpb
=
dist_graph
.
get_partition_book
()
gpb
=
dist_graph
.
get_partition_book
()
sampled_graph
=
sample_neighbors
(
dist_graph
,
nodes
,
3
)
# Enable santity check in distributed sampling.
os
.
environ
[
"DGL_DIST_DEBUG"
]
=
"1"
sampled_graph
=
sample_neighbors
(
dist_graph
,
nodes
,
3
,
use_graphbolt
=
use_graphbolt
)
block
=
dgl
.
to_block
(
sampled_graph
,
nodes
)
block
=
dgl
.
to_block
(
sampled_graph
,
nodes
)
if
sampled_graph
.
num_edges
()
>
0
:
if
sampled_graph
.
num_edges
()
>
0
:
block
.
edata
[
dgl
.
EID
]
=
sampled_graph
.
edata
[
dgl
.
EID
]
if
not
use_graphbolt
or
return_eids
:
block
.
edata
[
dgl
.
EID
]
=
sampled_graph
.
edata
[
dgl
.
EID
]
dgl
.
distributed
.
exit_client
()
dgl
.
distributed
.
exit_client
()
return
block
,
gpb
return
block
,
gpb
...
@@ -812,7 +880,9 @@ def start_bipartite_etype_sample_client(
...
@@ -812,7 +880,9 @@ def start_bipartite_etype_sample_client(
return
block
,
gpb
return
block
,
gpb
def
check_rpc_bipartite_sampling_empty
(
tmpdir
,
num_server
):
def
check_rpc_bipartite_sampling_empty
(
tmpdir
,
num_server
,
use_graphbolt
=
False
,
return_eids
=
False
):
"""sample on bipartite via sample_neighbors() which yields empty sample results"""
"""sample on bipartite via sample_neighbors() which yields empty sample results"""
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
...
@@ -828,6 +898,8 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
...
@@ -828,6 +898,8 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
num_hops
=
num_hops
,
num_hops
=
num_hops
,
part_method
=
"metis"
,
part_method
=
"metis"
,
return_mapping
=
True
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
)
pserver_list
=
[]
pserver_list
=
[]
...
@@ -835,7 +907,14 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
...
@@ -835,7 +907,14 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
for
i
in
range
(
num_server
):
for
i
in
range
(
num_server
):
p
=
ctx
.
Process
(
p
=
ctx
.
Process
(
target
=
start_server
,
target
=
start_server
,
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
),
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
],
use_graphbolt
,
),
)
)
p
.
start
()
p
.
start
()
time
.
sleep
(
1
)
time
.
sleep
(
1
)
...
@@ -844,7 +923,12 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
...
@@ -844,7 +923,12 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
deg
=
get_degrees
(
g
,
orig_nids
[
"game"
],
"game"
)
deg
=
get_degrees
(
g
,
orig_nids
[
"game"
],
"game"
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
block
,
_
=
start_bipartite_sample_client
(
block
,
_
=
start_bipartite_sample_client
(
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"game"
:
empty_nids
,
"user"
:
[
1
]}
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"game"
:
empty_nids
,
"user"
:
[
1
]},
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
)
print
(
"Done sampling"
)
print
(
"Done sampling"
)
...
@@ -856,7 +940,9 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
...
@@ -856,7 +940,9 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
assert
len
(
block
.
etypes
)
==
len
(
g
.
etypes
)
assert
len
(
block
.
etypes
)
==
len
(
g
.
etypes
)
def
check_rpc_bipartite_sampling_shuffle
(
tmpdir
,
num_server
):
def
check_rpc_bipartite_sampling_shuffle
(
tmpdir
,
num_server
,
use_graphbolt
=
False
,
return_eids
=
False
):
"""sample on bipartite via sample_neighbors() which yields non-empty sample results"""
"""sample on bipartite via sample_neighbors() which yields non-empty sample results"""
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
...
@@ -872,6 +958,8 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
...
@@ -872,6 +958,8 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
num_hops
=
num_hops
,
num_hops
=
num_hops
,
part_method
=
"metis"
,
part_method
=
"metis"
,
return_mapping
=
True
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
)
pserver_list
=
[]
pserver_list
=
[]
...
@@ -879,7 +967,14 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
...
@@ -879,7 +967,14 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
for
i
in
range
(
num_server
):
for
i
in
range
(
num_server
):
p
=
ctx
.
Process
(
p
=
ctx
.
Process
(
target
=
start_server
,
target
=
start_server
,
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
),
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
],
use_graphbolt
,
),
)
)
p
.
start
()
p
.
start
()
time
.
sleep
(
1
)
time
.
sleep
(
1
)
...
@@ -888,7 +983,12 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
...
@@ -888,7 +983,12 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
deg
=
get_degrees
(
g
,
orig_nid_map
[
"game"
],
"game"
)
deg
=
get_degrees
(
g
,
orig_nid_map
[
"game"
],
"game"
)
nids
=
F
.
nonzero_1d
(
deg
>
0
)
nids
=
F
.
nonzero_1d
(
deg
>
0
)
block
,
gpb
=
start_bipartite_sample_client
(
block
,
gpb
=
start_bipartite_sample_client
(
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"game"
:
nids
,
"user"
:
[
0
]}
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"game"
:
nids
,
"user"
:
[
0
]},
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
)
print
(
"Done sampling"
)
print
(
"Done sampling"
)
for
p
in
pserver_list
:
for
p
in
pserver_list
:
...
@@ -901,10 +1001,16 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
...
@@ -901,10 +1001,16 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
# These are global Ids after shuffling.
# These are global Ids after shuffling.
shuffled_src
=
F
.
gather_row
(
block
.
srcnodes
[
src_type
].
data
[
dgl
.
NID
],
src
)
shuffled_src
=
F
.
gather_row
(
block
.
srcnodes
[
src_type
].
data
[
dgl
.
NID
],
src
)
shuffled_dst
=
F
.
gather_row
(
block
.
dstnodes
[
dst_type
].
data
[
dgl
.
NID
],
dst
)
shuffled_dst
=
F
.
gather_row
(
block
.
dstnodes
[
dst_type
].
data
[
dgl
.
NID
],
dst
)
shuffled_eid
=
block
.
edges
[
etype
].
data
[
dgl
.
EID
]
orig_src
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
src_type
],
shuffled_src
))
orig_src
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
src_type
],
shuffled_src
))
orig_dst
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
dst_type
],
shuffled_dst
))
orig_dst
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
dst_type
],
shuffled_dst
))
assert
np
.
all
(
F
.
asnumpy
(
g
.
has_edges_between
(
orig_src
,
orig_dst
,
etype
=
etype
))
)
if
use_graphbolt
and
not
return_eids
:
continue
shuffled_eid
=
block
.
edges
[
etype
].
data
[
dgl
.
EID
]
orig_eid
=
F
.
asnumpy
(
F
.
gather_row
(
orig_eid_map
[
c_etype
],
shuffled_eid
))
orig_eid
=
F
.
asnumpy
(
F
.
gather_row
(
orig_eid_map
[
c_etype
],
shuffled_eid
))
# Check the node Ids and edge Ids.
# Check the node Ids and edge Ids.
...
@@ -1032,19 +1138,35 @@ def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids):
...
@@ -1032,19 +1138,35 @@ def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids):
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_hetero_sampling_shuffle
(
num_server
):
@
pytest
.
mark
.
parametrize
(
"use_graphbolt,"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_hetero_sampling_shuffle
(
num_server
,
use_graphbolt
,
return_eids
):
reset_envs
()
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_hetero_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
check_rpc_hetero_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_hetero_sampling_empty_shuffle
(
num_server
):
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_hetero_sampling_empty_shuffle
(
num_server
,
use_graphbolt
,
return_eids
):
reset_envs
()
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_hetero_sampling_empty_shuffle
(
Path
(
tmpdirname
),
num_server
)
check_rpc_hetero_sampling_empty_shuffle
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
...
@@ -1071,19 +1193,29 @@ def test_rpc_hetero_etype_sampling_empty_shuffle(num_server):
...
@@ -1071,19 +1193,29 @@ def test_rpc_hetero_etype_sampling_empty_shuffle(num_server):
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_bipartite_sampling_empty_shuffle
(
num_server
):
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_bipartite_sampling_empty_shuffle
(
num_server
,
use_graphbolt
,
return_eids
):
reset_envs
()
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_bipartite_sampling_empty
(
Path
(
tmpdirname
),
num_server
)
check_rpc_bipartite_sampling_empty
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
,
return_eids
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_bipartite_sampling_shuffle
(
num_server
):
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_bipartite_sampling_shuffle
(
num_server
,
use_graphbolt
,
return_eids
):
reset_envs
()
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_bipartite_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
check_rpc_bipartite_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
,
return_eids
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment