Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
ebf505ee
Unverified
Commit
ebf505ee
authored
Feb 28, 2024
by
Rhett Ying
Committed by
GitHub
Feb 28, 2024
Browse files
[DistGB] update top level API about use_graphbolt (#7169)
parent
7b4c8c77
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
20 additions
and
38 deletions
+20
-38
python/dgl/distributed/constants.py
python/dgl/distributed/constants.py
+0
-3
python/dgl/distributed/dist_context.py
python/dgl/distributed/dist_context.py
+5
-10
python/dgl/distributed/dist_graph.py
python/dgl/distributed/dist_graph.py
+15
-23
tests/distributed/test_mp_dataloader.py
tests/distributed/test_mp_dataloader.py
+0
-2
No files found.
python/dgl/distributed/constants.py
View file @
ebf505ee
...
@@ -7,6 +7,3 @@ SERVER_EXIT = "server_exit"
...
@@ -7,6 +7,3 @@ SERVER_EXIT = "server_exit"
DEFAULT_NTYPE
=
"_N"
DEFAULT_NTYPE
=
"_N"
DEFAULT_ETYPE
=
(
DEFAULT_NTYPE
,
"_E"
,
DEFAULT_NTYPE
)
DEFAULT_ETYPE
=
(
DEFAULT_NTYPE
,
"_E"
,
DEFAULT_NTYPE
)
DATA_LOADING_BACKEND_DGL
=
"DGL"
DATA_LOADING_BACKEND_GRAPHBOLT
=
"GraphBolt"
python/dgl/distributed/dist_context.py
View file @
ebf505ee
...
@@ -14,11 +14,7 @@ from enum import Enum
...
@@ -14,11 +14,7 @@ from enum import Enum
from
..
import
utils
from
..
import
utils
from
..base
import
dgl_warning
,
DGLError
from
..base
import
dgl_warning
,
DGLError
from
.
import
rpc
from
.
import
rpc
from
.constants
import
(
from
.constants
import
MAX_QUEUE_SIZE
DATA_LOADING_BACKEND_DGL
,
DATA_LOADING_BACKEND_GRAPHBOLT
,
MAX_QUEUE_SIZE
,
)
from
.kvstore
import
close_kvstore
,
init_kvstore
from
.kvstore
import
close_kvstore
,
init_kvstore
from
.role
import
init_role
from
.role
import
init_role
from
.rpc_client
import
connect_to_server
from
.rpc_client
import
connect_to_server
...
@@ -214,7 +210,7 @@ def initialize(
...
@@ -214,7 +210,7 @@ def initialize(
max_queue_size
=
MAX_QUEUE_SIZE
,
max_queue_size
=
MAX_QUEUE_SIZE
,
net_type
=
None
,
net_type
=
None
,
num_worker_threads
=
1
,
num_worker_threads
=
1
,
data_loading_backend
=
DATA_LOADING_BACKEND_DGL
,
use_graphbolt
=
False
,
):
):
"""Initialize DGL's distributed module
"""Initialize DGL's distributed module
...
@@ -236,8 +232,8 @@ def initialize(
...
@@ -236,8 +232,8 @@ def initialize(
[Deprecated] Networking type, can be 'socket' only.
[Deprecated] Networking type, can be 'socket' only.
num_worker_threads: int
num_worker_threads: int
The number of OMP threads in each sampler process.
The number of OMP threads in each sampler process.
data_loading_backend: str
, optional
use_graphbolt: bool
, optional
T
he
backend for data loading. Can be 'DGL' or 'GraphBolt'
.
W
he
ther to use GraphBolt for distributed train
.
Note
Note
----
----
...
@@ -277,8 +273,7 @@ def initialize(
...
@@ -277,8 +273,7 @@ def initialize(
int
(
os
.
environ
.
get
(
"DGL_NUM_CLIENT"
)),
int
(
os
.
environ
.
get
(
"DGL_NUM_CLIENT"
)),
os
.
environ
.
get
(
"DGL_CONF_PATH"
),
os
.
environ
.
get
(
"DGL_CONF_PATH"
),
graph_format
=
formats
,
graph_format
=
formats
,
use_graphbolt
=
data_loading_backend
use_graphbolt
=
use_graphbolt
,
==
DATA_LOADING_BACKEND_GRAPHBOLT
,
)
)
serv
.
start
()
serv
.
start
()
sys
.
exit
()
sys
.
exit
()
...
...
python/dgl/distributed/dist_graph.py
View file @
ebf505ee
...
@@ -18,7 +18,6 @@ from ..heterograph import DGLGraph
...
@@ -18,7 +18,6 @@ from ..heterograph import DGLGraph
from
..ndarray
import
exist_shared_mem_array
from
..ndarray
import
exist_shared_mem_array
from
..transforms
import
compact_graphs
from
..transforms
import
compact_graphs
from
.
import
graph_services
,
role
,
rpc
from
.
import
graph_services
,
role
,
rpc
from
.constants
import
DATA_LOADING_BACKEND_DGL
,
DATA_LOADING_BACKEND_GRAPHBOLT
from
.dist_tensor
import
DistTensor
from
.dist_tensor
import
DistTensor
from
.graph_partition_book
import
(
from
.graph_partition_book
import
(
_etype_str_to_tuple
,
_etype_str_to_tuple
,
...
@@ -51,7 +50,7 @@ from .shared_mem_utils import (
...
@@ -51,7 +50,7 @@ from .shared_mem_utils import (
)
)
INIT_GRAPH
=
800001
INIT_GRAPH
=
800001
QUERY_
DATA_LOADING_BACKEND
=
800002
QUERY_
IF_USE_GRAPHBOLT
=
800002
class
InitGraphRequest
(
rpc
.
Request
):
class
InitGraphRequest
(
rpc
.
Request
):
...
@@ -92,8 +91,8 @@ class InitGraphResponse(rpc.Response):
...
@@ -92,8 +91,8 @@ class InitGraphResponse(rpc.Response):
self
.
_graph_name
=
state
self
.
_graph_name
=
state
class
Query
DataLoadingBackend
Request
(
rpc
.
Request
):
class
Query
IfUseGraphBolt
Request
(
rpc
.
Request
):
"""Query
the data loading backend
."""
"""Query
if use GraphBolt
."""
def
__getstate__
(
self
):
def
__getstate__
(
self
):
return
None
return
None
...
@@ -102,25 +101,20 @@ class QueryDataLoadingBackendRequest(rpc.Request):
...
@@ -102,25 +101,20 @@ class QueryDataLoadingBackendRequest(rpc.Request):
pass
pass
def
process_request
(
self
,
server_state
):
def
process_request
(
self
,
server_state
):
backend
=
(
return
QueryIfUseGraphBoltResponse
(
server_state
.
use_graphbolt
)
DATA_LOADING_BACKEND_GRAPHBOLT
if
server_state
.
use_graphbolt
else
DATA_LOADING_BACKEND_DGL
)
return
QueryDataLoadingBackendResponse
(
backend
)
class
Query
DataLoadingBackend
Response
(
rpc
.
Response
):
class
Query
IfUseGraphBolt
Response
(
rpc
.
Response
):
"""Ack the query
data loading backend request
"""
"""Ack the query
request about if use GraphBolt.
"""
def
__init__
(
self
,
backend
):
def
__init__
(
self
,
use_graphbolt
):
self
.
_
backend
=
backend
self
.
_
use_graphbolt
=
use_graphbolt
def
__getstate__
(
self
):
def
__getstate__
(
self
):
return
self
.
_
backend
return
self
.
_
use_graphbolt
def
__setstate__
(
self
,
state
):
def
__setstate__
(
self
,
state
):
self
.
_
backend
=
state
self
.
_
use_graphbolt
=
state
def
_copy_graph_to_shared_mem
(
g
,
graph_name
,
graph_format
,
use_graphbolt
):
def
_copy_graph_to_shared_mem
(
g
,
graph_name
,
graph_format
,
use_graphbolt
):
...
@@ -638,10 +632,8 @@ class DistGraph:
...
@@ -638,10 +632,8 @@ class DistGraph:
rpc
.
set_num_client
(
1
)
rpc
.
set_num_client
(
1
)
else
:
else
:
# Query the main server about whether GraphBolt is used.
# Query the main server about whether GraphBolt is used.
rpc
.
send_request
(
0
,
QueryDataLoadingBackendRequest
())
rpc
.
send_request
(
0
,
QueryIfUseGraphBoltRequest
())
self
.
_use_graphbolt
=
(
self
.
_use_graphbolt
=
rpc
.
recv_response
().
_use_graphbolt
rpc
.
recv_response
().
_backend
==
DATA_LOADING_BACKEND_GRAPHBOLT
)
self
.
_init
(
gpb
)
self
.
_init
(
gpb
)
# Tell the backup servers to load the graph structure from shared memory.
# Tell the backup servers to load the graph structure from shared memory.
...
@@ -1869,7 +1861,7 @@ def edge_split(
...
@@ -1869,7 +1861,7 @@ def edge_split(
rpc
.
register_service
(
INIT_GRAPH
,
InitGraphRequest
,
InitGraphResponse
)
rpc
.
register_service
(
INIT_GRAPH
,
InitGraphRequest
,
InitGraphResponse
)
rpc
.
register_service
(
rpc
.
register_service
(
QUERY_
DATA_LOADING_BACKEND
,
QUERY_
IF_USE_GRAPHBOLT
,
Query
DataLoadingBackend
Request
,
Query
IfUseGraphBolt
Request
,
Query
DataLoadingBackend
Response
,
Query
IfUseGraphBolt
Response
,
)
)
tests/distributed/test_mp_dataloader.py
View file @
ebf505ee
...
@@ -11,8 +11,6 @@ import pytest
...
@@ -11,8 +11,6 @@ import pytest
import
torch
as
th
import
torch
as
th
from
dgl.data
import
CitationGraphDataset
from
dgl.data
import
CitationGraphDataset
from
dgl.distributed
import
(
from
dgl.distributed
import
(
DATA_LOADING_BACKEND_DGL
,
DATA_LOADING_BACKEND_GRAPHBOLT
,
DistDataLoader
,
DistDataLoader
,
DistGraph
,
DistGraph
,
DistGraphServer
,
DistGraphServer
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment