Unverified commit ebf505ee, authored by Rhett Ying and committed by GitHub
Browse files

[DistGB] update top level API about use_graphbolt (#7169)

parent 7b4c8c77
...@@ -7,6 +7,3 @@ SERVER_EXIT = "server_exit" ...@@ -7,6 +7,3 @@ SERVER_EXIT = "server_exit"
DEFAULT_NTYPE = "_N" DEFAULT_NTYPE = "_N"
DEFAULT_ETYPE = (DEFAULT_NTYPE, "_E", DEFAULT_NTYPE) DEFAULT_ETYPE = (DEFAULT_NTYPE, "_E", DEFAULT_NTYPE)
DATA_LOADING_BACKEND_DGL = "DGL"
DATA_LOADING_BACKEND_GRAPHBOLT = "GraphBolt"
...@@ -14,11 +14,7 @@ from enum import Enum ...@@ -14,11 +14,7 @@ from enum import Enum
from .. import utils from .. import utils
from ..base import dgl_warning, DGLError from ..base import dgl_warning, DGLError
from . import rpc from . import rpc
from .constants import ( from .constants import MAX_QUEUE_SIZE
DATA_LOADING_BACKEND_DGL,
DATA_LOADING_BACKEND_GRAPHBOLT,
MAX_QUEUE_SIZE,
)
from .kvstore import close_kvstore, init_kvstore from .kvstore import close_kvstore, init_kvstore
from .role import init_role from .role import init_role
from .rpc_client import connect_to_server from .rpc_client import connect_to_server
...@@ -214,7 +210,7 @@ def initialize( ...@@ -214,7 +210,7 @@ def initialize(
max_queue_size=MAX_QUEUE_SIZE, max_queue_size=MAX_QUEUE_SIZE,
net_type=None, net_type=None,
num_worker_threads=1, num_worker_threads=1,
data_loading_backend=DATA_LOADING_BACKEND_DGL, use_graphbolt=False,
): ):
"""Initialize DGL's distributed module """Initialize DGL's distributed module
...@@ -236,8 +232,8 @@ def initialize( ...@@ -236,8 +232,8 @@ def initialize(
[Deprecated] Networking type, can be 'socket' only. [Deprecated] Networking type, can be 'socket' only.
num_worker_threads: int num_worker_threads: int
The number of OMP threads in each sampler process. The number of OMP threads in each sampler process.
data_loading_backend: str, optional use_graphbolt: bool, optional
The backend for data loading. Can be 'DGL' or 'GraphBolt'. Whether to use GraphBolt for distributed training.
Note Note
---- ----
...@@ -277,8 +273,7 @@ def initialize( ...@@ -277,8 +273,7 @@ def initialize(
int(os.environ.get("DGL_NUM_CLIENT")), int(os.environ.get("DGL_NUM_CLIENT")),
os.environ.get("DGL_CONF_PATH"), os.environ.get("DGL_CONF_PATH"),
graph_format=formats, graph_format=formats,
use_graphbolt=data_loading_backend use_graphbolt=use_graphbolt,
== DATA_LOADING_BACKEND_GRAPHBOLT,
) )
serv.start() serv.start()
sys.exit() sys.exit()
......
...@@ -18,7 +18,6 @@ from ..heterograph import DGLGraph ...@@ -18,7 +18,6 @@ from ..heterograph import DGLGraph
from ..ndarray import exist_shared_mem_array from ..ndarray import exist_shared_mem_array
from ..transforms import compact_graphs from ..transforms import compact_graphs
from . import graph_services, role, rpc from . import graph_services, role, rpc
from .constants import DATA_LOADING_BACKEND_DGL, DATA_LOADING_BACKEND_GRAPHBOLT
from .dist_tensor import DistTensor from .dist_tensor import DistTensor
from .graph_partition_book import ( from .graph_partition_book import (
_etype_str_to_tuple, _etype_str_to_tuple,
...@@ -51,7 +50,7 @@ from .shared_mem_utils import ( ...@@ -51,7 +50,7 @@ from .shared_mem_utils import (
) )
INIT_GRAPH = 800001 INIT_GRAPH = 800001
QUERY_DATA_LOADING_BACKEND = 800002 QUERY_IF_USE_GRAPHBOLT = 800002
class InitGraphRequest(rpc.Request): class InitGraphRequest(rpc.Request):
...@@ -92,8 +91,8 @@ class InitGraphResponse(rpc.Response): ...@@ -92,8 +91,8 @@ class InitGraphResponse(rpc.Response):
self._graph_name = state self._graph_name = state
class QueryDataLoadingBackendRequest(rpc.Request): class QueryIfUseGraphBoltRequest(rpc.Request):
"""Query the data loading backend.""" """Query if use GraphBolt."""
def __getstate__(self): def __getstate__(self):
return None return None
...@@ -102,25 +101,20 @@ class QueryDataLoadingBackendRequest(rpc.Request): ...@@ -102,25 +101,20 @@ class QueryDataLoadingBackendRequest(rpc.Request):
pass pass
def process_request(self, server_state): def process_request(self, server_state):
backend = ( return QueryIfUseGraphBoltResponse(server_state.use_graphbolt)
DATA_LOADING_BACKEND_GRAPHBOLT
if server_state.use_graphbolt
else DATA_LOADING_BACKEND_DGL
)
return QueryDataLoadingBackendResponse(backend)
class QueryDataLoadingBackendResponse(rpc.Response): class QueryIfUseGraphBoltResponse(rpc.Response):
"""Ack the query data loading backend request""" """Ack the query request about if use GraphBolt."""
def __init__(self, backend): def __init__(self, use_graphbolt):
self._backend = backend self._use_graphbolt = use_graphbolt
def __getstate__(self): def __getstate__(self):
return self._backend return self._use_graphbolt
def __setstate__(self, state): def __setstate__(self, state):
self._backend = state self._use_graphbolt = state
def _copy_graph_to_shared_mem(g, graph_name, graph_format, use_graphbolt): def _copy_graph_to_shared_mem(g, graph_name, graph_format, use_graphbolt):
...@@ -638,10 +632,8 @@ class DistGraph: ...@@ -638,10 +632,8 @@ class DistGraph:
rpc.set_num_client(1) rpc.set_num_client(1)
else: else:
# Query the main server about whether GraphBolt is used. # Query the main server about whether GraphBolt is used.
rpc.send_request(0, QueryDataLoadingBackendRequest()) rpc.send_request(0, QueryIfUseGraphBoltRequest())
self._use_graphbolt = ( self._use_graphbolt = rpc.recv_response()._use_graphbolt
rpc.recv_response()._backend == DATA_LOADING_BACKEND_GRAPHBOLT
)
self._init(gpb) self._init(gpb)
# Tell the backup servers to load the graph structure from shared memory. # Tell the backup servers to load the graph structure from shared memory.
...@@ -1869,7 +1861,7 @@ def edge_split( ...@@ -1869,7 +1861,7 @@ def edge_split(
rpc.register_service(INIT_GRAPH, InitGraphRequest, InitGraphResponse) rpc.register_service(INIT_GRAPH, InitGraphRequest, InitGraphResponse)
rpc.register_service( rpc.register_service(
QUERY_DATA_LOADING_BACKEND, QUERY_IF_USE_GRAPHBOLT,
QueryDataLoadingBackendRequest, QueryIfUseGraphBoltRequest,
QueryDataLoadingBackendResponse, QueryIfUseGraphBoltResponse,
) )
...@@ -11,8 +11,6 @@ import pytest ...@@ -11,8 +11,6 @@ import pytest
import torch as th import torch as th
from dgl.data import CitationGraphDataset from dgl.data import CitationGraphDataset
from dgl.distributed import ( from dgl.distributed import (
DATA_LOADING_BACKEND_DGL,
DATA_LOADING_BACKEND_GRAPHBOLT,
DistDataLoader, DistDataLoader,
DistGraph, DistGraph,
DistGraphServer, DistGraphServer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment