Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b0080d5b
Unverified
Commit
b0080d5b
authored
Feb 22, 2024
by
Rhett Ying
Committed by
GitHub
Feb 22, 2024
Browse files
[DistGB] add graphbolt flag into top level API (#7122)
parent
7a10bcb6
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
88 additions
and
43 deletions
+88
-43
python/dgl/distributed/__init__.py
python/dgl/distributed/__init__.py
+1
-0
python/dgl/distributed/constants.py
python/dgl/distributed/constants.py
+3
-0
python/dgl/distributed/dist_context.py
python/dgl/distributed/dist_context.py
+10
-1
python/dgl/distributed/dist_graph.py
python/dgl/distributed/dist_graph.py
+53
-17
python/dgl/distributed/server_state.py
python/dgl/distributed/server_state.py
+9
-1
tests/distributed/test_dist_graph_store.py
tests/distributed/test_dist_graph_store.py
+4
-4
tests/distributed/test_distributed_sampling.py
tests/distributed/test_distributed_sampling.py
+5
-15
tests/distributed/test_mp_dataloader.py
tests/distributed/test_mp_dataloader.py
+3
-5
No files found.
python/dgl/distributed/__init__.py
View file @
b0080d5b
...
...
@@ -19,3 +19,4 @@ from .rpc import *
from
.rpc_client
import
connect_to_server
from
.rpc_server
import
start_server
from
.server_state
import
ServerState
from
.constants
import
*
python/dgl/distributed/constants.py
View file @
b0080d5b
...
...
@@ -7,3 +7,6 @@ SERVER_EXIT = "server_exit"
DEFAULT_NTYPE
=
"_N"
DEFAULT_ETYPE
=
(
DEFAULT_NTYPE
,
"_E"
,
DEFAULT_NTYPE
)
DATA_LOADING_BACKEND_DGL
=
"DGL"
DATA_LOADING_BACKEND_GRAPHBOLT
=
"GraphBolt"
python/dgl/distributed/dist_context.py
View file @
b0080d5b
...
...
@@ -14,7 +14,11 @@ from enum import Enum
from
..
import
utils
from
..base
import
dgl_warning
,
DGLError
from
.
import
rpc
from
.constants
import
MAX_QUEUE_SIZE
from
.constants
import
(
DATA_LOADING_BACKEND_DGL
,
DATA_LOADING_BACKEND_GRAPHBOLT
,
MAX_QUEUE_SIZE
,
)
from
.kvstore
import
close_kvstore
,
init_kvstore
from
.role
import
init_role
from
.rpc_client
import
connect_to_server
...
...
@@ -210,6 +214,7 @@ def initialize(
max_queue_size
=
MAX_QUEUE_SIZE
,
net_type
=
None
,
num_worker_threads
=
1
,
data_loading_backend
=
DATA_LOADING_BACKEND_DGL
,
):
"""Initialize DGL's distributed module
...
...
@@ -231,6 +236,8 @@ def initialize(
[Deprecated] Networking type, can be 'socket' only.
num_worker_threads: int
The number of OMP threads in each sampler process.
data_loading_backend: str, optional
The backend for data loading. Can be 'DGL' or 'GraphBolt'.
Note
----
...
...
@@ -270,6 +277,8 @@ def initialize(
int
(
os
.
environ
.
get
(
"DGL_NUM_CLIENT"
)),
os
.
environ
.
get
(
"DGL_CONF_PATH"
),
graph_format
=
formats
,
use_graphbolt
=
data_loading_backend
==
DATA_LOADING_BACKEND_GRAPHBOLT
,
)
serv
.
start
()
sys
.
exit
()
...
...
python/dgl/distributed/dist_graph.py
View file @
b0080d5b
...
...
@@ -18,6 +18,7 @@ from ..heterograph import DGLGraph
from
..ndarray
import
exist_shared_mem_array
from
..transforms
import
compact_graphs
from
.
import
graph_services
,
role
,
rpc
from
.constants
import
DATA_LOADING_BACKEND_DGL
,
DATA_LOADING_BACKEND_GRAPHBOLT
from
.dist_tensor
import
DistTensor
from
.graph_partition_book
import
(
_etype_str_to_tuple
,
...
...
@@ -50,6 +51,7 @@ from .shared_mem_utils import (
)
INIT_GRAPH
=
800001
QUERY_DATA_LOADING_BACKEND
=
800002
class
InitGraphRequest
(
rpc
.
Request
):
...
...
@@ -60,20 +62,19 @@ class InitGraphRequest(rpc.Request):
with shared memory.
"""
def
__init__
(
self
,
graph_name
,
use_graphbolt
):
def
__init__
(
self
,
graph_name
):
self
.
_graph_name
=
graph_name
self
.
_use_graphbolt
=
use_graphbolt
def
__getstate__
(
self
):
return
self
.
_graph_name
,
self
.
_use_graphbolt
return
self
.
_graph_name
def
__setstate__
(
self
,
state
):
self
.
_graph_name
,
self
.
_use_graphbolt
=
state
self
.
_graph_name
=
state
def
process_request
(
self
,
server_state
):
if
server_state
.
graph
is
None
:
server_state
.
graph
=
_get_graph_from_shared_mem
(
self
.
_graph_name
,
se
lf
.
_
use_graphbolt
self
.
_graph_name
,
se
rver_state
.
use_graphbolt
)
return
InitGraphResponse
(
self
.
_graph_name
)
...
...
@@ -91,6 +92,37 @@ class InitGraphResponse(rpc.Response):
self
.
_graph_name
=
state
class
QueryDataLoadingBackendRequest
(
rpc
.
Request
):
"""Query the data loading backend."""
def
__getstate__
(
self
):
return
None
def
__setstate__
(
self
,
state
):
pass
def
process_request
(
self
,
server_state
):
backend
=
(
DATA_LOADING_BACKEND_GRAPHBOLT
if
server_state
.
use_graphbolt
else
DATA_LOADING_BACKEND_DGL
)
return
QueryDataLoadingBackendResponse
(
backend
)
class
QueryDataLoadingBackendResponse
(
rpc
.
Response
):
"""Ack the query data loading backend request"""
def
__init__
(
self
,
backend
):
self
.
_backend
=
backend
def
__getstate__
(
self
):
return
self
.
_backend
def
__setstate__
(
self
,
state
):
self
.
_backend
=
state
def
_copy_graph_to_shared_mem
(
g
,
graph_name
,
graph_format
,
use_graphbolt
):
if
use_graphbolt
:
return
g
.
copy_to_shared_memory
(
graph_name
)
...
...
@@ -473,6 +505,7 @@ class DistGraphServer(KVServer):
kv_store
=
self
,
local_g
=
self
.
client_g
,
partition_book
=
self
.
gpb
,
use_graphbolt
=
self
.
use_graphbolt
,
)
print
(
"start graph service on server {} for part {}"
.
format
(
...
...
@@ -529,8 +562,6 @@ class DistGraph:
part_config : str, optional
The path of partition configuration file generated by
:py:meth:`dgl.distributed.partition.partition_graph`. It's used in the standalone mode.
use_graphbolt : bool, optional
Whether to load GraphBolt partition. Default: False.
Examples
--------
...
...
@@ -564,15 +595,11 @@ class DistGraph:
manually setting up servers and trainers. The setup is not fully tested yet.
"""
def
__init__
(
self
,
graph_name
,
gpb
=
None
,
part_config
=
None
,
use_graphbolt
=
False
):
def
__init__
(
self
,
graph_name
,
gpb
=
None
,
part_config
=
None
):
self
.
graph_name
=
graph_name
self
.
_use_graphbolt
=
use_graphbolt
if
os
.
environ
.
get
(
"DGL_DIST_MODE"
,
"standalone"
)
==
"standalone"
:
assert
(
use_graphbolt
is
False
),
"GraphBolt is not supported in standalone mode."
# "GraphBolt is not supported in standalone mode."
self
.
_use_graphbolt
=
False
assert
(
part_config
is
not
None
),
"When running in the standalone model, the partition config file is required"
...
...
@@ -610,12 +637,16 @@ class DistGraph:
self
.
_client
.
map_shared_data
(
self
.
_gpb
)
rpc
.
set_num_client
(
1
)
else
:
# Query the main server about whether GraphBolt is used.
rpc
.
send_request
(
0
,
QueryDataLoadingBackendRequest
())
self
.
_use_graphbolt
=
(
rpc
.
recv_response
().
_backend
==
DATA_LOADING_BACKEND_GRAPHBOLT
)
self
.
_init
(
gpb
)
# Tell the backup servers to load the graph structure from shared memory.
for
server_id
in
range
(
self
.
_client
.
num_servers
):
rpc
.
send_request
(
server_id
,
InitGraphRequest
(
graph_name
,
use_graphbolt
)
)
rpc
.
send_request
(
server_id
,
InitGraphRequest
(
graph_name
))
for
server_id
in
range
(
self
.
_client
.
num_servers
):
rpc
.
recv_response
()
self
.
_client
.
barrier
()
...
...
@@ -1832,3 +1863,8 @@ def edge_split(
rpc
.
register_service
(
INIT_GRAPH
,
InitGraphRequest
,
InitGraphResponse
)
rpc
.
register_service
(
QUERY_DATA_LOADING_BACKEND
,
QueryDataLoadingBackendRequest
,
QueryDataLoadingBackendResponse
,
)
python/dgl/distributed/server_state.py
View file @
b0080d5b
...
...
@@ -38,13 +38,16 @@ class ServerState:
Total number of edges
partition_book : GraphPartitionBook
Graph Partition book
use_graphbolt : bool
Whether to use graphbolt for dataloading.
"""
def
__init__
(
self
,
kv_store
,
local_g
,
partition_book
):
def
__init__
(
self
,
kv_store
,
local_g
,
partition_book
,
use_graphbolt
=
False
):
self
.
_kv_store
=
kv_store
self
.
_graph
=
local_g
self
.
partition_book
=
partition_book
self
.
_roles
=
{}
self
.
_use_graphbolt
=
use_graphbolt
@
property
def
roles
(
self
):
...
...
@@ -69,5 +72,10 @@ class ServerState:
def
graph
(
self
,
graph
):
self
.
_graph
=
graph
@
property
def
use_graphbolt
(
self
):
"""Whether to use graphbolt for dataloading."""
return
self
.
_use_graphbolt
_init_api
(
"dgl.distributed.server_state"
)
tests/distributed/test_dist_graph_store.py
View file @
b0080d5b
...
...
@@ -141,7 +141,7 @@ def run_client_empty(
gpb
,
graph_name
,
_
,
_
=
load_partition_book
(
"/tmp/dist_graph/{}.json"
.
format
(
graph_name
),
part_id
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
)
check_dist_graph_empty
(
g
,
num_clients
,
num_nodes
,
num_edges
)
...
...
@@ -222,7 +222,7 @@ def run_client(
gpb
,
graph_name
,
_
,
_
=
load_partition_book
(
"/tmp/dist_graph/{}.json"
.
format
(
graph_name
),
part_id
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
)
check_dist_graph
(
g
,
num_clients
,
num_nodes
,
num_edges
,
use_graphbolt
=
use_graphbolt
)
...
...
@@ -322,7 +322,7 @@ def run_client_hierarchy(
gpb
,
graph_name
,
_
,
_
=
load_partition_book
(
"/tmp/dist_graph/{}.json"
.
format
(
graph_name
),
part_id
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
)
node_mask
=
F
.
tensor
(
node_mask
)
edge_mask
=
F
.
tensor
(
edge_mask
)
nodes
=
node_split
(
...
...
@@ -742,7 +742,7 @@ def run_client_hetero(
gpb
,
graph_name
,
_
,
_
=
load_partition_book
(
"/tmp/dist_graph/{}.json"
.
format
(
graph_name
),
part_id
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
g
=
DistGraph
(
graph_name
,
gpb
=
gpb
)
check_dist_graph_hetero
(
g
,
num_clients
,
num_nodes
,
num_edges
,
use_graphbolt
=
use_graphbolt
)
...
...
tests/distributed/test_distributed_sampling.py
View file @
b0080d5b
...
...
@@ -84,9 +84,7 @@ def start_sample_client_shuffle(
tmpdir
/
"test_sampling.json"
,
rank
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
sampled_graph
=
sample_neighbors
(
dist_graph
,
[
0
,
10
,
99
,
66
,
1024
,
2008
],
3
,
use_graphbolt
=
use_graphbolt
)
...
...
@@ -477,9 +475,7 @@ def start_hetero_sample_client(
tmpdir
/
"test_sampling.json"
,
rank
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
assert
"feat"
in
dist_graph
.
nodes
[
"n1"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n2"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n3"
].
data
...
...
@@ -517,9 +513,7 @@ def start_hetero_etype_sample_client(
tmpdir
/
"test_sampling.json"
,
rank
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
assert
"feat"
in
dist_graph
.
nodes
[
"n1"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n2"
].
data
assert
"feat"
not
in
dist_graph
.
nodes
[
"n3"
].
data
...
...
@@ -876,9 +870,7 @@ def start_bipartite_sample_client(
tmpdir
/
"test_sampling.json"
,
rank
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
assert
"feat"
in
dist_graph
.
nodes
[
"user"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"game"
].
data
if
gpb
is
None
:
...
...
@@ -911,9 +903,7 @@ def start_bipartite_etype_sample_client(
tmpdir
/
"test_sampling.json"
,
rank
)
dgl
.
distributed
.
initialize
(
"rpc_ip_config.txt"
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
,
use_graphbolt
=
use_graphbolt
)
dist_graph
=
DistGraph
(
"test_sampling"
,
gpb
=
gpb
)
assert
"feat"
in
dist_graph
.
nodes
[
"user"
].
data
assert
"feat"
in
dist_graph
.
nodes
[
"game"
].
data
...
...
tests/distributed/test_mp_dataloader.py
View file @
b0080d5b
...
...
@@ -11,6 +11,8 @@ import pytest
import
torch
as
th
from
dgl.data
import
CitationGraphDataset
from
dgl.distributed
import
(
DATA_LOADING_BACKEND_DGL
,
DATA_LOADING_BACKEND_GRAPHBOLT
,
DistDataLoader
,
DistGraph
,
DistGraphServer
,
...
...
@@ -104,7 +106,6 @@ def start_dist_dataloader(
"test_sampling"
,
gpb
=
gpb
,
part_config
=
part_config
,
use_graphbolt
=
use_graphbolt
,
)
# Create sampler
...
...
@@ -443,7 +444,6 @@ def start_node_dataloader(
"test_sampling"
,
gpb
=
gpb
,
part_config
=
part_config
,
use_graphbolt
=
use_graphbolt
,
)
assert
len
(
dist_graph
.
ntypes
)
==
len
(
groundtruth_g
.
ntypes
)
assert
len
(
dist_graph
.
etypes
)
==
len
(
groundtruth_g
.
etypes
)
...
...
@@ -763,9 +763,7 @@ def start_multiple_dataloaders(
use_graphbolt
,
):
dgl
.
distributed
.
initialize
(
ip_config
)
dist_g
=
dgl
.
distributed
.
DistGraph
(
graph_name
,
part_config
=
part_config
,
use_graphbolt
=
use_graphbolt
)
dist_g
=
dgl
.
distributed
.
DistGraph
(
graph_name
,
part_config
=
part_config
)
if
dataloader_type
==
"node"
:
train_ids
=
th
.
arange
(
orig_g
.
num_nodes
())
batch_size
=
orig_g
.
num_nodes
()
//
100
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment