Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
df6b3250
Unverified
Commit
df6b3250
authored
Feb 01, 2024
by
Rhett Ying
Committed by
GitHub
Feb 01, 2024
Browse files
[DistGB] enable DistGraph to load graphbolt partitions (#7048)
parent
942b17ab
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
189 additions
and
62 deletions
+189
-62
python/dgl/distributed/dist_graph.py
python/dgl/distributed/dist_graph.py
+30
-10
tests/distributed/test_dist_graph_store.py
tests/distributed/test_dist_graph_store.py
+159
-52
No files found.
python/dgl/distributed/dist_graph.py
View file @
df6b3250
...
...
@@ -60,18 +60,21 @@ class InitGraphRequest(rpc.Request):
with shared memory.
"""
def
__init__
(
self
,
graph_name
):
def
__init__
(
self
,
graph_name
,
use_graphbolt
):
self
.
_graph_name
=
graph_name
self
.
_use_graphbolt
=
use_graphbolt
def
__getstate__
(
self
):
return
self
.
_graph_name
return
self
.
_graph_name
,
self
.
_use_graphbolt
def
__setstate__
(
self
,
state
):
self
.
_graph_name
=
state
self
.
_graph_name
,
self
.
_use_graphbolt
=
state
def
process_request
(
self
,
server_state
):
if
server_state
.
graph
is
None
:
server_state
.
graph
=
_get_graph_from_shared_mem
(
self
.
_graph_name
)
server_state
.
graph
=
_get_graph_from_shared_mem
(
self
.
_graph_name
,
self
.
_use_graphbolt
)
return
InitGraphResponse
(
self
.
_graph_name
)
...
...
@@ -153,13 +156,15 @@ def _exist_shared_mem_array(graph_name, name):
return
exist_shared_mem_array
(
_get_edata_path
(
graph_name
,
name
))
def
_get_graph_from_shared_mem
(
graph_name
):
def
_get_graph_from_shared_mem
(
graph_name
,
use_graphbolt
):
"""Get the graph from the DistGraph server.
The DistGraph server puts the graph structure of the local partition in the shared memory.
The client can access the graph structure and some metadata on nodes and edges directly
through shared memory to reduce the overhead of data access.
"""
if
use_graphbolt
:
return
gb
.
load_from_shared_memory
(
graph_name
)
g
,
ntypes
,
etypes
=
heterograph_index
.
create_heterograph_from_shared_memory
(
graph_name
)
...
...
@@ -524,6 +529,8 @@ class DistGraph:
part_config : str, optional
The path of partition configuration file generated by
:py:meth:`dgl.distributed.partition.partition_graph`. It's used in the standalone mode.
use_graphbolt : bool, optional
Whether to load GraphBolt partition. Default: False.
Examples
--------
...
...
@@ -557,9 +564,15 @@ class DistGraph:
manually setting up servers and trainers. The setup is not fully tested yet.
"""
def
__init__
(
self
,
graph_name
,
gpb
=
None
,
part_config
=
None
):
def
__init__
(
self
,
graph_name
,
gpb
=
None
,
part_config
=
None
,
use_graphbolt
=
False
):
self
.
graph_name
=
graph_name
self
.
_use_graphbolt
=
use_graphbolt
if
os
.
environ
.
get
(
"DGL_DIST_MODE"
,
"standalone"
)
==
"standalone"
:
assert
(
use_graphbolt
is
False
),
"GraphBolt is not supported in standalone mode."
assert
(
part_config
is
not
None
),
"When running in the standalone model, the partition config file is required"
...
...
@@ -600,7 +613,9 @@ class DistGraph:
self
.
_init
(
gpb
)
# Tell the backup servers to load the graph structure from shared memory.
for
server_id
in
range
(
self
.
_client
.
num_servers
):
rpc
.
send_request
(
server_id
,
InitGraphRequest
(
graph_name
))
rpc
.
send_request
(
server_id
,
InitGraphRequest
(
graph_name
,
use_graphbolt
)
)
for
server_id
in
range
(
self
.
_client
.
num_servers
):
rpc
.
recv_response
()
self
.
_client
.
barrier
()
...
...
@@ -625,7 +640,9 @@ class DistGraph:
assert
(
self
.
_client
is
not
None
),
"Distributed module is not initialized. Please call dgl.distributed.initialize."
self
.
_g
=
_get_graph_from_shared_mem
(
self
.
graph_name
)
self
.
_g
=
_get_graph_from_shared_mem
(
self
.
graph_name
,
self
.
_use_graphbolt
)
self
.
_gpb
=
get_shared_mem_partition_book
(
self
.
graph_name
)
if
self
.
_gpb
is
None
:
self
.
_gpb
=
gpb
...
...
@@ -682,10 +699,10 @@ class DistGraph:
self
.
_edata_store
[
etype
]
=
data
def
__getstate__
(
self
):
return
self
.
graph_name
,
self
.
_gpb
return
self
.
graph_name
,
self
.
_gpb
,
self
.
_use_graphbolt
def
__setstate__
(
self
,
state
):
self
.
graph_name
,
gpb
=
state
self
.
graph_name
,
gpb
,
self
.
_use_graphbolt
=
state
self
.
_init
(
gpb
)
self
.
_init_ndata_store
()
...
...
@@ -1230,6 +1247,9 @@ class DistGraph:
tensor
The destination node ID array.
"""
assert
(
self
.
_use_graphbolt
is
False
),
"find_edges is not supported in GraphBolt."
if
etype
is
None
:
assert
(
len
(
self
.
etypes
)
==
1
...
...
tests/distributed/test_dist_graph_store.py
View file @
df6b3250
...
...
@@ -13,11 +13,13 @@ from multiprocessing import Condition, Manager, Process, Value
import
backend
as
F
import
dgl
import
dgl.graphbolt
as
gb
import
numpy
as
np
import
pytest
import
torch
as
th
from
dgl.data.utils
import
load_graphs
,
save_graphs
from
dgl.distributed
import
(
dgl_partition_to_graphbolt
,
DistEmbedding
,
DistGraph
,
DistGraphServer
,
...
...
@@ -38,12 +40,33 @@ if os.name != "nt":
import
struct
def
_verify_dist_graph_server_dgl
(
g
):
# verify dtype of underlying graph
cg
=
g
.
client_g
for
k
,
dtype
in
dgl
.
distributed
.
dist_graph
.
RESERVED_FIELD_DTYPE
.
items
():
if
k
in
cg
.
ndata
:
assert
(
F
.
dtype
(
cg
.
ndata
[
k
])
==
dtype
),
"Data type of {} in ndata should be {}."
.
format
(
k
,
dtype
)
if
k
in
cg
.
edata
:
assert
(
F
.
dtype
(
cg
.
edata
[
k
])
==
dtype
),
"Data type of {} in edata should be {}."
.
format
(
k
,
dtype
)
def
_verify_dist_graph_server_graphbolt
(
g
):
graph
=
g
.
client_g
assert
isinstance
(
graph
,
gb
.
FusedCSCSamplingGraph
)
# [Rui][TODO] verify dtype of underlying graph.
def
run_server
(
graph_name
,
server_id
,
server_count
,
num_clients
,
shared_mem
,
use_graphbolt
=
False
,
):
g
=
DistGraphServer
(
server_id
,
...
...
@@ -53,19 +76,15 @@ def run_server(
"/tmp/dist_graph/{}.json"
.
format
(
graph_name
),
disable_shared_mem
=
not
shared_mem
,
graph_format
=
[
"csc"
,
"coo"
],
use_graphbolt
=
use_graphbolt
,
)
print
(
"start server"
,
server_id
)
# verify dtype of underlying graph
cg
=
g
.
client_g
for
k
,
dtype
in
dgl
.
distributed
.
dist_graph
.
RESERVED_FIELD_DTYPE
.
items
():
if
k
in
cg
.
ndata
:
assert
(
F
.
dtype
(
cg
.
ndata
[
k
])
==
dtype
),
"Data type of {} in ndata should be {}."
.
format
(
k
,
dtype
)
if
k
in
cg
.
edata
:
assert
(
F
.
dtype
(
cg
.
edata
[
k
])
==
dtype
),
"Data type of {} in edata should be {}."
.
format
(
k
,
dtype
)
print
(
f
"Starting server[
{
server_id
}
] with use_graphbolt=
{
use_graphbolt
}
"
)
_verify
=
(
_verify_dist_graph_server_graphbolt
if
use_graphbolt
else
_verify_dist_graph_server_dgl
)
_verify
(
g
)
g
.
start
()
...
...
@@ -110,18 +129,26 @@ def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
def
run_client_empty
(
graph_name
,
part_id
,
server_count
,
num_clients
,
num_nodes
,
num_edges
graph_name
,
part_id
,
server_count
,
num_clients
,
num_nodes
,
num_edges
,
use_graphbolt
=
False
,
):
os
.
environ
[
"DGL_NUM_SERVER"
]
=
str
(
server_count
)
dgl
.
distributed
.
initialize
(
"kv_ip_config.txt"
)
gpb
,
graph_name
,
_
,
_
=
load_partition_book
(
"/tmp/dist_graph/{}.json"
.
format
(
graph_name
),
part_id
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
check_dist_graph_empty
(
g
,
num_clients
,
num_nodes
,
num_edges
)
def
check_server_client_empty
(
shared_mem
,
num_servers
,
num_clients
):
def
check_server_client_empty
(
shared_mem
,
num_servers
,
num_clients
,
use_graphbolt
=
False
):
prepare_dist
(
num_servers
)
g
=
create_random_graph
(
10000
)
...
...
@@ -129,6 +156,9 @@ def check_server_client_empty(shared_mem, num_servers, num_clients):
num_parts
=
1
graph_name
=
"dist_graph_test_1"
partition_graph
(
g
,
graph_name
,
num_parts
,
"/tmp/dist_graph"
)
if
use_graphbolt
:
part_config
=
os
.
path
.
join
(
"/tmp/dist_graph"
,
f
"
{
graph_name
}
.json"
)
dgl_partition_to_graphbolt
(
part_config
)
# let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine.
...
...
@@ -137,7 +167,14 @@ def check_server_client_empty(shared_mem, num_servers, num_clients):
for
serv_id
in
range
(
num_servers
):
p
=
ctx
.
Process
(
target
=
run_server
,
args
=
(
graph_name
,
serv_id
,
num_servers
,
num_clients
,
shared_mem
),
args
=
(
graph_name
,
serv_id
,
num_servers
,
num_clients
,
shared_mem
,
use_graphbolt
,
),
)
serv_ps
.
append
(
p
)
p
.
start
()
...
...
@@ -154,6 +191,7 @@ def check_server_client_empty(shared_mem, num_servers, num_clients):
num_clients
,
g
.
num_nodes
(),
g
.
num_edges
(),
use_graphbolt
,
),
)
p
.
start
()
...
...
@@ -178,6 +216,7 @@ def run_client(
num_nodes
,
num_edges
,
group_id
,
use_graphbolt
=
False
,
):
os
.
environ
[
"DGL_NUM_SERVER"
]
=
str
(
server_count
)
os
.
environ
[
"DGL_GROUP_ID"
]
=
str
(
group_id
)
...
...
@@ -185,8 +224,10 @@ def run_client(
gpb
,
graph_name
,
_
,
_
=
load_partition_book
(
"/tmp/dist_graph/{}.json"
.
format
(
graph_name
),
part_id
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
)
check_dist_graph
(
g
,
num_clients
,
num_nodes
,
num_edges
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
check_dist_graph
(
g
,
num_clients
,
num_nodes
,
num_edges
,
use_graphbolt
=
use_graphbolt
)
def
run_emb_client
(
...
...
@@ -270,14 +311,20 @@ def check_dist_optim_store(rank, num_nodes, optimizer_states, save):
def
run_client_hierarchy
(
graph_name
,
part_id
,
server_count
,
node_mask
,
edge_mask
,
return_dict
graph_name
,
part_id
,
server_count
,
node_mask
,
edge_mask
,
return_dict
,
use_graphbolt
=
False
,
):
os
.
environ
[
"DGL_NUM_SERVER"
]
=
str
(
server_count
)
dgl
.
distributed
.
initialize
(
"kv_ip_config.txt"
)
gpb
,
graph_name
,
_
,
_
=
load_partition_book
(
"/tmp/dist_graph/{}.json"
.
format
(
graph_name
),
part_id
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
node_mask
=
F
.
tensor
(
node_mask
)
edge_mask
=
F
.
tensor
(
edge_mask
)
nodes
=
node_split
(
...
...
@@ -355,7 +402,7 @@ def check_dist_emb(g, num_clients, num_nodes, num_edges):
sys
.
exit
(
-
1
)
def
check_dist_graph
(
g
,
num_clients
,
num_nodes
,
num_edges
):
def
check_dist_graph
(
g
,
num_clients
,
num_nodes
,
num_edges
,
use_graphbolt
=
False
):
# Test API
assert
g
.
num_nodes
()
==
num_nodes
assert
g
.
num_edges
()
==
num_edges
...
...
@@ -373,6 +420,12 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges):
assert
np
.
all
(
F
.
asnumpy
(
feats
==
eids
))
# Test edge_subgraph
if
use_graphbolt
:
with
pytest
.
raises
(
AssertionError
,
match
=
"find_edges is not supported in GraphBolt."
):
g
.
edge_subgraph
(
eids
)
else
:
sg
=
g
.
edge_subgraph
(
eids
)
assert
sg
.
num_edges
()
==
len
(
eids
)
assert
F
.
array_equal
(
sg
.
edata
[
dgl
.
EID
],
eids
)
...
...
@@ -522,7 +575,9 @@ def check_dist_emb_server_client(
print
(
"clients have terminated"
)
def
check_server_client
(
shared_mem
,
num_servers
,
num_clients
,
num_groups
=
1
):
def
check_server_client
(
shared_mem
,
num_servers
,
num_clients
,
num_groups
=
1
,
use_graphbolt
=
False
):
prepare_dist
(
num_servers
)
g
=
create_random_graph
(
10000
)
...
...
@@ -532,6 +587,9 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
g
.
ndata
[
"features"
]
=
F
.
unsqueeze
(
F
.
arange
(
0
,
g
.
num_nodes
()),
1
)
g
.
edata
[
"features"
]
=
F
.
unsqueeze
(
F
.
arange
(
0
,
g
.
num_edges
()),
1
)
partition_graph
(
g
,
graph_name
,
num_parts
,
"/tmp/dist_graph"
)
if
use_graphbolt
:
part_config
=
os
.
path
.
join
(
"/tmp/dist_graph"
,
f
"
{
graph_name
}
.json"
)
dgl_partition_to_graphbolt
(
part_config
)
# let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine.
...
...
@@ -546,6 +604,7 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
num_servers
,
num_clients
,
shared_mem
,
use_graphbolt
,
),
)
serv_ps
.
append
(
p
)
...
...
@@ -566,6 +625,7 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
g
.
num_nodes
(),
g
.
num_edges
(),
group_id
,
use_graphbolt
,
),
)
p
.
start
()
...
...
@@ -582,7 +642,12 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
print
(
"clients have terminated"
)
def
check_server_client_hierarchy
(
shared_mem
,
num_servers
,
num_clients
):
def
check_server_client_hierarchy
(
shared_mem
,
num_servers
,
num_clients
,
use_graphbolt
=
False
):
if
num_clients
==
1
:
# skip this test if there is only one client.
return
prepare_dist
(
num_servers
)
g
=
create_random_graph
(
10000
)
...
...
@@ -598,6 +663,9 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
"/tmp/dist_graph"
,
num_trainers_per_machine
=
num_clients
,
)
if
use_graphbolt
:
part_config
=
os
.
path
.
join
(
"/tmp/dist_graph"
,
f
"
{
graph_name
}
.json"
)
dgl_partition_to_graphbolt
(
part_config
)
# let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine.
...
...
@@ -606,7 +674,14 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
for
serv_id
in
range
(
num_servers
):
p
=
ctx
.
Process
(
target
=
run_server
,
args
=
(
graph_name
,
serv_id
,
num_servers
,
num_clients
,
shared_mem
),
args
=
(
graph_name
,
serv_id
,
num_servers
,
num_clients
,
shared_mem
,
use_graphbolt
,
),
)
serv_ps
.
append
(
p
)
p
.
start
()
...
...
@@ -633,6 +708,7 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
node_mask
,
edge_mask
,
return_dict
,
use_graphbolt
,
),
)
p
.
start
()
...
...
@@ -658,15 +734,23 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
def
run_client_hetero
(
graph_name
,
part_id
,
server_count
,
num_clients
,
num_nodes
,
num_edges
graph_name
,
part_id
,
server_count
,
num_clients
,
num_nodes
,
num_edges
,
use_graphbolt
=
False
,
):
os
.
environ
[
"DGL_NUM_SERVER"
]
=
str
(
server_count
)
dgl
.
distributed
.
initialize
(
"kv_ip_config.txt"
)
gpb
,
graph_name
,
_
,
_
=
load_partition_book
(
"/tmp/dist_graph/{}.json"
.
format
(
graph_name
),
part_id
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
)
check_dist_graph_hetero
(
g
,
num_clients
,
num_nodes
,
num_edges
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
check_dist_graph_hetero
(
g
,
num_clients
,
num_nodes
,
num_edges
,
use_graphbolt
=
use_graphbolt
)
def
create_random_hetero
():
...
...
@@ -701,7 +785,9 @@ def create_random_hetero():
return
g
def
check_dist_graph_hetero
(
g
,
num_clients
,
num_nodes
,
num_edges
):
def
check_dist_graph_hetero
(
g
,
num_clients
,
num_nodes
,
num_edges
,
use_graphbolt
=
False
):
# Test API
for
ntype
in
num_nodes
:
assert
ntype
in
g
.
ntypes
...
...
@@ -754,6 +840,12 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
assert
expect_except
# Test edge_subgraph
if
use_graphbolt
:
with
pytest
.
raises
(
AssertionError
,
match
=
"find_edges is not supported in GraphBolt."
):
g
.
edge_subgraph
({
"r1"
:
eids
})
else
:
sg
=
g
.
edge_subgraph
({
"r1"
:
eids
})
assert
sg
.
num_edges
()
==
len
(
eids
)
assert
F
.
array_equal
(
sg
.
edata
[
dgl
.
EID
],
eids
)
...
...
@@ -827,7 +919,9 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
print
(
"end"
)
def
check_server_client_hetero
(
shared_mem
,
num_servers
,
num_clients
):
def
check_server_client_hetero
(
shared_mem
,
num_servers
,
num_clients
,
use_graphbolt
=
False
):
prepare_dist
(
num_servers
)
g
=
create_random_hetero
()
...
...
@@ -835,6 +929,9 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
num_parts
=
1
graph_name
=
"dist_graph_test_3"
partition_graph
(
g
,
graph_name
,
num_parts
,
"/tmp/dist_graph"
)
if
use_graphbolt
:
part_config
=
os
.
path
.
join
(
"/tmp/dist_graph"
,
f
"
{
graph_name
}
.json"
)
dgl_partition_to_graphbolt
(
part_config
)
# let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine.
...
...
@@ -843,7 +940,14 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
for
serv_id
in
range
(
num_servers
):
p
=
ctx
.
Process
(
target
=
run_server
,
args
=
(
graph_name
,
serv_id
,
num_servers
,
num_clients
,
shared_mem
),
args
=
(
graph_name
,
serv_id
,
num_servers
,
num_clients
,
shared_mem
,
use_graphbolt
,
),
)
serv_ps
.
append
(
p
)
p
.
start
()
...
...
@@ -862,6 +966,7 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
num_clients
,
num_nodes
,
num_edges
,
use_graphbolt
,
),
)
p
.
start
()
...
...
@@ -886,21 +991,23 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
@
unittest
.
skipIf
(
dgl
.
backend
.
backend_name
==
"mxnet"
,
reason
=
"Turn off Mxnet support"
)
def
test_server_client
():
@
pytest
.
mark
.
parametrize
(
"shared_mem"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"num_servers"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_clients"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"use_graphbolt"
,
[
True
,
False
])
def
test_server_client
(
shared_mem
,
num_servers
,
num_clients
,
use_graphbolt
):
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
check_server_client_hierarchy
(
False
,
1
,
4
)
check_server_client_empty
(
True
,
1
,
1
)
check_server_client_hetero
(
True
,
1
,
1
)
check_server_client_hetero
(
False
,
1
,
1
)
check_server_client
(
True
,
1
,
1
)
check_server_client
(
False
,
1
,
1
)
# [TODO][Rhett] Tests for multiple groups may fail sometimes and
# root cause is unknown. Let's disable them for now.
# check_server_client(True, 2, 2)
# check_server_client(True, 1, 1, 2)
# check_server_client(False, 1, 1, 2)
# check_server_client(True, 2, 2, 2)
# [Rui]
# 1. `disable_shared_mem=False` is not supported yet. Skip it.
# 2. `num_servers` > 1 does not work on single machine. Skip it.
for
func
in
[
check_server_client
,
check_server_client_hetero
,
check_server_client_empty
,
check_server_client_hierarchy
,
]:
func
(
shared_mem
,
num_servers
,
num_clients
,
use_graphbolt
=
use_graphbolt
)
@
unittest
.
skip
(
reason
=
"Skip due to glitch in CI"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment