Unverified Commit ebf505ee authored by Rhett Ying, committed by GitHub
Browse files

[DistGB] update top level API about use_graphbolt (#7169)

parent 7b4c8c77
......@@ -7,6 +7,3 @@ SERVER_EXIT = "server_exit"
DEFAULT_NTYPE = "_N"
DEFAULT_ETYPE = (DEFAULT_NTYPE, "_E", DEFAULT_NTYPE)
DATA_LOADING_BACKEND_DGL = "DGL"
DATA_LOADING_BACKEND_GRAPHBOLT = "GraphBolt"
......@@ -14,11 +14,7 @@ from enum import Enum
from .. import utils
from ..base import dgl_warning, DGLError
from . import rpc
from .constants import (
DATA_LOADING_BACKEND_DGL,
DATA_LOADING_BACKEND_GRAPHBOLT,
MAX_QUEUE_SIZE,
)
from .constants import MAX_QUEUE_SIZE
from .kvstore import close_kvstore, init_kvstore
from .role import init_role
from .rpc_client import connect_to_server
......@@ -214,7 +210,7 @@ def initialize(
max_queue_size=MAX_QUEUE_SIZE,
net_type=None,
num_worker_threads=1,
data_loading_backend=DATA_LOADING_BACKEND_DGL,
use_graphbolt=False,
):
"""Initialize DGL's distributed module
......@@ -236,8 +232,8 @@ def initialize(
[Deprecated] Networking type, can be 'socket' only.
num_worker_threads: int
The number of OMP threads in each sampler process.
data_loading_backend: str, optional
The backend for data loading. Can be 'DGL' or 'GraphBolt'.
use_graphbolt: bool, optional
Whether to use GraphBolt for distributed training.
Note
----
......@@ -277,8 +273,7 @@ def initialize(
int(os.environ.get("DGL_NUM_CLIENT")),
os.environ.get("DGL_CONF_PATH"),
graph_format=formats,
use_graphbolt=data_loading_backend
== DATA_LOADING_BACKEND_GRAPHBOLT,
use_graphbolt=use_graphbolt,
)
serv.start()
sys.exit()
......
......@@ -18,7 +18,6 @@ from ..heterograph import DGLGraph
from ..ndarray import exist_shared_mem_array
from ..transforms import compact_graphs
from . import graph_services, role, rpc
from .constants import DATA_LOADING_BACKEND_DGL, DATA_LOADING_BACKEND_GRAPHBOLT
from .dist_tensor import DistTensor
from .graph_partition_book import (
_etype_str_to_tuple,
......@@ -51,7 +50,7 @@ from .shared_mem_utils import (
)
INIT_GRAPH = 800001
QUERY_DATA_LOADING_BACKEND = 800002
QUERY_IF_USE_GRAPHBOLT = 800002
class InitGraphRequest(rpc.Request):
......@@ -92,8 +91,8 @@ class InitGraphResponse(rpc.Response):
self._graph_name = state
class QueryDataLoadingBackendRequest(rpc.Request):
"""Query the data loading backend."""
class QueryIfUseGraphBoltRequest(rpc.Request):
"""Query if use GraphBolt."""
def __getstate__(self):
return None
......@@ -102,25 +101,20 @@ class QueryDataLoadingBackendRequest(rpc.Request):
pass
def process_request(self, server_state):
backend = (
DATA_LOADING_BACKEND_GRAPHBOLT
if server_state.use_graphbolt
else DATA_LOADING_BACKEND_DGL
)
return QueryDataLoadingBackendResponse(backend)
return QueryIfUseGraphBoltResponse(server_state.use_graphbolt)
class QueryDataLoadingBackendResponse(rpc.Response):
"""Ack the query data loading backend request"""
class QueryIfUseGraphBoltResponse(rpc.Response):
"""Ack the query request about if use GraphBolt."""
def __init__(self, backend):
self._backend = backend
def __init__(self, use_graphbolt):
self._use_graphbolt = use_graphbolt
def __getstate__(self):
return self._backend
return self._use_graphbolt
def __setstate__(self, state):
self._backend = state
self._use_graphbolt = state
def _copy_graph_to_shared_mem(g, graph_name, graph_format, use_graphbolt):
......@@ -638,10 +632,8 @@ class DistGraph:
rpc.set_num_client(1)
else:
# Query the main server about whether GraphBolt is used.
rpc.send_request(0, QueryDataLoadingBackendRequest())
self._use_graphbolt = (
rpc.recv_response()._backend == DATA_LOADING_BACKEND_GRAPHBOLT
)
rpc.send_request(0, QueryIfUseGraphBoltRequest())
self._use_graphbolt = rpc.recv_response()._use_graphbolt
self._init(gpb)
# Tell the backup servers to load the graph structure from shared memory.
......@@ -1869,7 +1861,7 @@ def edge_split(
rpc.register_service(INIT_GRAPH, InitGraphRequest, InitGraphResponse)
rpc.register_service(
QUERY_DATA_LOADING_BACKEND,
QueryDataLoadingBackendRequest,
QueryDataLoadingBackendResponse,
QUERY_IF_USE_GRAPHBOLT,
QueryIfUseGraphBoltRequest,
QueryIfUseGraphBoltResponse,
)
......@@ -11,8 +11,6 @@ import pytest
import torch as th
from dgl.data import CitationGraphDataset
from dgl.distributed import (
DATA_LOADING_BACKEND_DGL,
DATA_LOADING_BACKEND_GRAPHBOLT,
DistDataLoader,
DistGraph,
DistGraphServer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment