Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
ee8b7b39
Unverified
Commit
ee8b7b39
authored
Feb 06, 2024
by
Rhett Ying
Committed by
GitHub
Feb 06, 2024
Browse files
[DistGB] enable GB sampling on heterograph (#7087)
parent
a2e1c796
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
196 additions
and
40 deletions
+196
-40
python/dgl/distributed/graph_services.py
python/dgl/distributed/graph_services.py
+30
-6
tests/distributed/test_distributed_sampling.py
tests/distributed/test_distributed_sampling.py
+166
-34
No files found.
python/dgl/distributed/graph_services.py
View file @
ee8b7b39
"""A set of graph services of getting subgraphs from DistGraph"""
import
os
from
collections
import
namedtuple
import
numpy
as
np
...
...
@@ -708,24 +709,47 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb):
idtype
=
g
.
idtype
,
)
etype_ids
,
frontier
.
edata
[
EID
]
=
gpb
.
map_to_per_etype
(
frontier
.
edata
[
EID
])
src
,
dst
=
frontier
.
edges
()
# For DGL partitions, the global edge IDs are always stored in the edata.
# For GraphBolt partitions, the edge type IDs are always stored in the
# edata. As for the edge IDs, they are stored in the edata if the graph is
# partitioned with `store_eids=True`. Otherwise, the edge IDs are not
# stored.
etype_ids
,
type_wise_eids
=
(
gpb
.
map_to_per_etype
(
frontier
.
edata
[
EID
])
if
EID
in
frontier
.
edata
else
(
frontier
.
edata
[
ETYPE
],
None
)
)
etype_ids
,
idx
=
F
.
sort_1d
(
etype_ids
)
if
type_wise_eids
is
not
None
:
type_wise_eids
=
F
.
gather_row
(
type_wise_eids
,
idx
)
# Sort the edges by their edge types.
src
,
dst
=
frontier
.
edges
()
src
,
dst
=
F
.
gather_row
(
src
,
idx
),
F
.
gather_row
(
dst
,
idx
)
eid
=
F
.
gather_row
(
frontier
.
edata
[
EID
],
idx
)
_
,
src
=
gpb
.
map_to_per_ntype
(
src
)
_
,
dst
=
gpb
.
map_to_per_ntype
(
dst
)
src_ntype_ids
,
src
=
gpb
.
map_to_per_ntype
(
src
)
dst_ntype_ids
,
dst
=
gpb
.
map_to_per_ntype
(
dst
)
data_dict
=
dict
()
edge_ids
=
{}
for
etid
,
etype
in
enumerate
(
g
.
canonical_etypes
):
src_ntype
,
_
,
dst_ntype
=
etype
src_ntype_id
=
g
.
get_ntype_id
(
src_ntype
)
dst_ntype_id
=
g
.
get_ntype_id
(
dst_ntype
)
type_idx
=
etype_ids
==
etid
if
F
.
sum
(
type_idx
,
0
)
>
0
:
data_dict
[
etype
]
=
(
F
.
boolean_mask
(
src
,
type_idx
),
F
.
boolean_mask
(
dst
,
type_idx
),
)
edge_ids
[
etype
]
=
F
.
boolean_mask
(
eid
,
type_idx
)
if
"DGL_DIST_DEBUG"
in
os
.
environ
:
assert
torch
.
all
(
src_ntype_id
==
src_ntype_ids
[
type_idx
]
),
"source ntype is is not expected."
assert
torch
.
all
(
dst_ntype_id
==
dst_ntype_ids
[
type_idx
]
),
"destination ntype is is not expected."
if
type_wise_eids
is
not
None
:
edge_ids
[
etype
]
=
F
.
boolean_mask
(
type_wise_eids
,
type_idx
)
hg
=
heterograph
(
data_dict
,
{
ntype
:
g
.
num_nodes
(
ntype
)
for
ntype
in
g
.
ntypes
},
...
...
tests/distributed/test_distributed_sampling.py
View file @
ee8b7b39
...
...
@@ -91,6 +91,9 @@ def start_sample_client_shuffle(
dist_graph
,
[
0
,
10
,
99
,
66
,
1024
,
2008
],
3
,
use_graphbolt
=
use_graphbolt
)
assert
(
dgl
.
ETYPE
not
in
sampled_graph
.
edata
),
"Etype should not be in homogeneous sampled graph."
src
,
dst
=
sampled_graph
.
edges
()
src
=
orig_nid
[
src
]
dst
=
orig_nid
[
dst
]
...
...
@@ -460,22 +463,36 @@ def check_rpc_sampling_shuffle(
assert
p
.
exitcode
==
0
def
start_hetero_sample_client
(
rank
,
tmpdir
,
disable_shared_mem
,
nodes
):
def
start_hetero_sample_client
(
rank
,
tmpdir
,
disable_shared_mem
,
nodes
,
use_graphbolt
=
False
,
return_eids
=
False
,
):
gpb
=
None
if
disable_shared_mem
:
_
,
_
,
_
,
gpb
,
_
,
_
,
_
=
load_partition
(
tmpdir
/
"test_sampling.json"
,
rank
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
assert
"feat"
in
dist_graph
.
nodes
[
"n1"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n2"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n3"
].
data
if
gpb
is
None
:
gpb
=
dist_graph
.
get_partition_book
()
try
:
sampled_graph
=
sample_neighbors
(
dist_graph
,
nodes
,
3
)
# Enable santity check in distributed sampling.
os
.
environ
[
"DGL_DIST_DEBUG"
]
=
"1"
sampled_graph
=
sample_neighbors
(
dist_graph
,
nodes
,
3
,
use_graphbolt
=
use_graphbolt
)
block
=
dgl
.
to_block
(
sampled_graph
,
nodes
)
if
not
use_graphbolt
or
return_eids
:
block
.
edata
[
dgl
.
EID
]
=
sampled_graph
.
edata
[
dgl
.
EID
]
except
Exception
as
e
:
print
(
traceback
.
format_exc
())
...
...
@@ -528,7 +545,9 @@ def start_hetero_etype_sample_client(
return
block
,
gpb
def
check_rpc_hetero_sampling_shuffle
(
tmpdir
,
num_server
):
def
check_rpc_hetero_sampling_shuffle
(
tmpdir
,
num_server
,
use_graphbolt
=
False
,
return_eids
=
False
):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
g
=
create_random_hetero
()
...
...
@@ -543,6 +562,8 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
num_hops
=
num_hops
,
part_method
=
"metis"
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
pserver_list
=
[]
...
...
@@ -550,16 +571,27 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
for
i
in
range
(
num_server
):
p
=
ctx
.
Process
(
target
=
start_server
,
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
),
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
],
use_graphbolt
,
),
)
p
.
start
()
time
.
sleep
(
1
)
pserver_list
.
append
(
p
)
block
,
gpb
=
start_hetero_sample_client
(
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"n3"
:
[
0
,
10
,
99
,
66
,
124
,
208
]}
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"n3"
:
[
0
,
10
,
99
,
66
,
124
,
208
]},
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
print
(
"Done sampling"
)
for
p
in
pserver_list
:
p
.
join
()
assert
p
.
exitcode
==
0
...
...
@@ -570,10 +602,17 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
# These are global Ids after shuffling.
shuffled_src
=
F
.
gather_row
(
block
.
srcnodes
[
src_type
].
data
[
dgl
.
NID
],
src
)
shuffled_dst
=
F
.
gather_row
(
block
.
dstnodes
[
dst_type
].
data
[
dgl
.
NID
],
dst
)
shuffled_eid
=
block
.
edges
[
etype
].
data
[
dgl
.
EID
]
orig_src
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
src_type
],
shuffled_src
))
orig_dst
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
dst_type
],
shuffled_dst
))
assert
np
.
all
(
F
.
asnumpy
(
g
.
has_edges_between
(
orig_src
,
orig_dst
,
etype
=
etype
))
)
if
use_graphbolt
and
not
return_eids
:
continue
shuffled_eid
=
block
.
edges
[
etype
].
data
[
dgl
.
EID
]
orig_eid
=
F
.
asnumpy
(
F
.
gather_row
(
orig_eid_map
[
c_etype
],
shuffled_eid
))
# Check the node Ids and edge Ids.
...
...
@@ -592,7 +631,9 @@ def get_degrees(g, nids, ntype):
return
deg
def
check_rpc_hetero_sampling_empty_shuffle
(
tmpdir
,
num_server
):
def
check_rpc_hetero_sampling_empty_shuffle
(
tmpdir
,
num_server
,
use_graphbolt
=
False
,
return_eids
=
False
):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
g
=
create_random_hetero
(
empty
=
True
)
...
...
@@ -607,6 +648,8 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
num_hops
=
num_hops
,
part_method
=
"metis"
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
pserver_list
=
[]
...
...
@@ -614,7 +657,14 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
for
i
in
range
(
num_server
):
p
=
ctx
.
Process
(
target
=
start_server
,
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
),
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
],
use_graphbolt
,
),
)
p
.
start
()
time
.
sleep
(
1
)
...
...
@@ -623,9 +673,13 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
deg
=
get_degrees
(
g
,
orig_nids
[
"n3"
],
"n3"
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
block
,
gpb
=
start_hetero_sample_client
(
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"n3"
:
empty_nids
}
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"n3"
:
empty_nids
},
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
print
(
"Done sampling"
)
for
p
in
pserver_list
:
p
.
join
()
assert
p
.
exitcode
==
0
...
...
@@ -759,21 +813,35 @@ def create_random_bipartite():
return
g
def
start_bipartite_sample_client
(
rank
,
tmpdir
,
disable_shared_mem
,
nodes
):
def
start_bipartite_sample_client
(
rank
,
tmpdir
,
disable_shared_mem
,
nodes
,
use_graphbolt
=
False
,
return_eids
=
False
,
):
gpb
=
None
if
disable_shared_mem
:
_
,
_
,
_
,
gpb
,
_
,
_
,
_
=
load_partition
(
tmpdir
/
"test_sampling.json"
,
rank
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
assert
"feat"
in
dist_graph
.
nodes
[
"user"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"game"
].
data
if
gpb
is
None
:
gpb
=
dist_graph
.
get_partition_book
()
sampled_graph
=
sample_neighbors
(
dist_graph
,
nodes
,
3
)
# Enable santity check in distributed sampling.
os
.
environ
[
"DGL_DIST_DEBUG"
]
=
"1"
sampled_graph
=
sample_neighbors
(
dist_graph
,
nodes
,
3
,
use_graphbolt
=
use_graphbolt
)
block
=
dgl
.
to_block
(
sampled_graph
,
nodes
)
if
sampled_graph
.
num_edges
()
>
0
:
if
not
use_graphbolt
or
return_eids
:
block
.
edata
[
dgl
.
EID
]
=
sampled_graph
.
edata
[
dgl
.
EID
]
dgl
.
distributed
.
exit_client
()
return
block
,
gpb
...
...
@@ -812,7 +880,9 @@ def start_bipartite_etype_sample_client(
return
block
,
gpb
def
check_rpc_bipartite_sampling_empty
(
tmpdir
,
num_server
):
def
check_rpc_bipartite_sampling_empty
(
tmpdir
,
num_server
,
use_graphbolt
=
False
,
return_eids
=
False
):
"""sample on bipartite via sample_neighbors() which yields empty sample results"""
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
...
...
@@ -828,6 +898,8 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
num_hops
=
num_hops
,
part_method
=
"metis"
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
pserver_list
=
[]
...
...
@@ -835,7 +907,14 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
for
i
in
range
(
num_server
):
p
=
ctx
.
Process
(
target
=
start_server
,
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
),
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
],
use_graphbolt
,
),
)
p
.
start
()
time
.
sleep
(
1
)
...
...
@@ -844,7 +923,12 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
deg
=
get_degrees
(
g
,
orig_nids
[
"game"
],
"game"
)
empty_nids
=
F
.
nonzero_1d
(
deg
==
0
)
block
,
_
=
start_bipartite_sample_client
(
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"game"
:
empty_nids
,
"user"
:
[
1
]}
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"game"
:
empty_nids
,
"user"
:
[
1
]},
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
print
(
"Done sampling"
)
...
...
@@ -856,7 +940,9 @@ def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
assert
len
(
block
.
etypes
)
==
len
(
g
.
etypes
)
def
check_rpc_bipartite_sampling_shuffle
(
tmpdir
,
num_server
):
def
check_rpc_bipartite_sampling_shuffle
(
tmpdir
,
num_server
,
use_graphbolt
=
False
,
return_eids
=
False
):
"""sample on bipartite via sample_neighbors() which yields non-empty sample results"""
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
...
...
@@ -872,6 +958,8 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
num_hops
=
num_hops
,
part_method
=
"metis"
,
return_mapping
=
True
,
use_graphbolt
=
use_graphbolt
,
store_eids
=
return_eids
,
)
pserver_list
=
[]
...
...
@@ -879,7 +967,14 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
for
i
in
range
(
num_server
):
p
=
ctx
.
Process
(
target
=
start_server
,
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
),
args
=
(
i
,
tmpdir
,
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
],
use_graphbolt
,
),
)
p
.
start
()
time
.
sleep
(
1
)
...
...
@@ -888,7 +983,12 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
deg
=
get_degrees
(
g
,
orig_nid_map
[
"game"
],
"game"
)
nids
=
F
.
nonzero_1d
(
deg
>
0
)
block
,
gpb
=
start_bipartite_sample_client
(
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"game"
:
nids
,
"user"
:
[
0
]}
0
,
tmpdir
,
num_server
>
1
,
nodes
=
{
"game"
:
nids
,
"user"
:
[
0
]},
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
print
(
"Done sampling"
)
for
p
in
pserver_list
:
...
...
@@ -901,10 +1001,16 @@ def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
# These are global Ids after shuffling.
shuffled_src
=
F
.
gather_row
(
block
.
srcnodes
[
src_type
].
data
[
dgl
.
NID
],
src
)
shuffled_dst
=
F
.
gather_row
(
block
.
dstnodes
[
dst_type
].
data
[
dgl
.
NID
],
dst
)
shuffled_eid
=
block
.
edges
[
etype
].
data
[
dgl
.
EID
]
orig_src
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
src_type
],
shuffled_src
))
orig_dst
=
F
.
asnumpy
(
F
.
gather_row
(
orig_nid_map
[
dst_type
],
shuffled_dst
))
assert
np
.
all
(
F
.
asnumpy
(
g
.
has_edges_between
(
orig_src
,
orig_dst
,
etype
=
etype
))
)
if
use_graphbolt
and
not
return_eids
:
continue
shuffled_eid
=
block
.
edges
[
etype
].
data
[
dgl
.
EID
]
orig_eid
=
F
.
asnumpy
(
F
.
gather_row
(
orig_eid_map
[
c_etype
],
shuffled_eid
))
# Check the node Ids and edge Ids.
...
...
@@ -1032,19 +1138,35 @@ def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids):
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_hetero_sampling_shuffle
(
num_server
):
@
pytest
.
mark
.
parametrize
(
"use_graphbolt,"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_hetero_sampling_shuffle
(
num_server
,
use_graphbolt
,
return_eids
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_hetero_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
check_rpc_hetero_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_hetero_sampling_empty_shuffle
(
num_server
):
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_hetero_sampling_empty_shuffle
(
num_server
,
use_graphbolt
,
return_eids
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_hetero_sampling_empty_shuffle
(
Path
(
tmpdirname
),
num_server
)
check_rpc_hetero_sampling_empty_shuffle
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
=
use_graphbolt
,
return_eids
=
return_eids
,
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
...
...
@@ -1071,19 +1193,29 @@ def test_rpc_hetero_etype_sampling_empty_shuffle(num_server):
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_bipartite_sampling_empty_shuffle
(
num_server
):
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_bipartite_sampling_empty_shuffle
(
num_server
,
use_graphbolt
,
return_eids
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_bipartite_sampling_empty
(
Path
(
tmpdirname
),
num_server
)
check_rpc_bipartite_sampling_empty
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
,
return_eids
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
def
test_rpc_bipartite_sampling_shuffle
(
num_server
):
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"return_eids"
,
[
False
,
True
])
def
test_rpc_bipartite_sampling_shuffle
(
num_server
,
use_graphbolt
,
return_eids
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
check_rpc_bipartite_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
)
check_rpc_bipartite_sampling_shuffle
(
Path
(
tmpdirname
),
num_server
,
use_graphbolt
,
return_eids
)
@
pytest
.
mark
.
parametrize
(
"num_server"
,
[
1
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment