Unverified commit b0080d5b authored by Rhett Ying, committed by GitHub

[DistGB] add graphbolt flag into top level API (#7122)

parent 7a10bcb6
@@ -19,3 +19,4 @@ from .rpc import *
 from .rpc_client import connect_to_server
 from .rpc_server import start_server
 from .server_state import ServerState
+from .constants import *
@@ -7,3 +7,6 @@ SERVER_EXIT = "server_exit"
 DEFAULT_NTYPE = "_N"
 DEFAULT_ETYPE = (DEFAULT_NTYPE, "_E", DEFAULT_NTYPE)
+
+DATA_LOADING_BACKEND_DGL = "DGL"
+DATA_LOADING_BACKEND_GRAPHBOLT = "GraphBolt"
@@ -14,7 +14,11 @@ from enum import Enum
 from .. import utils
 from ..base import dgl_warning, DGLError
 from . import rpc
-from .constants import MAX_QUEUE_SIZE
+from .constants import (
+    DATA_LOADING_BACKEND_DGL,
+    DATA_LOADING_BACKEND_GRAPHBOLT,
+    MAX_QUEUE_SIZE,
+)
 from .kvstore import close_kvstore, init_kvstore
 from .role import init_role
 from .rpc_client import connect_to_server
@@ -210,6 +214,7 @@ def initialize(
     max_queue_size=MAX_QUEUE_SIZE,
     net_type=None,
     num_worker_threads=1,
+    data_loading_backend=DATA_LOADING_BACKEND_DGL,
 ):
     """Initialize DGL's distributed module
@@ -231,6 +236,8 @@
         [Deprecated] Networking type, can be 'socket' only.
     num_worker_threads: int
         The number of OMP threads in each sampler process.
+    data_loading_backend: str, optional
+        The backend for data loading. Can be 'DGL' or 'GraphBolt'.

     Note
     ----
@@ -270,6 +277,8 @@
             int(os.environ.get("DGL_NUM_CLIENT")),
             os.environ.get("DGL_CONF_PATH"),
             graph_format=formats,
+            use_graphbolt=data_loading_backend
+            == DATA_LOADING_BACKEND_GRAPHBOLT,
         )
         serv.start()
         sys.exit()
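For context: the new `data_loading_backend` flag is consumed by `initialize()` and turned into the servers' `use_graphbolt` flag via the comparison above. A minimal usage sketch, assuming an already partitioned graph and an IP config file (the path is a placeholder):

```python
import dgl

# Opt into the GraphBolt data loading backend at startup; the default
# is "DGL". The strings match the DATA_LOADING_BACKEND_* constants
# added in constants.py above.
dgl.distributed.initialize(
    "ip_config.txt",  # placeholder path to the cluster's IP config
    data_loading_backend="GraphBolt",
)
```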
@@ -18,6 +18,7 @@ from ..heterograph import DGLGraph
 from ..ndarray import exist_shared_mem_array
 from ..transforms import compact_graphs
 from . import graph_services, role, rpc
+from .constants import DATA_LOADING_BACKEND_DGL, DATA_LOADING_BACKEND_GRAPHBOLT
 from .dist_tensor import DistTensor
 from .graph_partition_book import (
     _etype_str_to_tuple,
@@ -50,6 +51,7 @@ from .shared_mem_utils import (
 )

 INIT_GRAPH = 800001
+QUERY_DATA_LOADING_BACKEND = 800002


 class InitGraphRequest(rpc.Request):
@@ -60,20 +62,19 @@ class InitGraphRequest(rpc.Request):
     with shared memory.
     """

-    def __init__(self, graph_name, use_graphbolt):
+    def __init__(self, graph_name):
         self._graph_name = graph_name
-        self._use_graphbolt = use_graphbolt

     def __getstate__(self):
-        return self._graph_name, self._use_graphbolt
+        return self._graph_name

     def __setstate__(self, state):
-        self._graph_name, self._use_graphbolt = state
+        self._graph_name = state

     def process_request(self, server_state):
         if server_state.graph is None:
             server_state.graph = _get_graph_from_shared_mem(
-                self._graph_name, self._use_graphbolt
+                self._graph_name, server_state.use_graphbolt
             )
         return InitGraphResponse(self._graph_name)
@@ -91,6 +92,37 @@ class InitGraphResponse(rpc.Response):
         self._graph_name = state


+class QueryDataLoadingBackendRequest(rpc.Request):
+    """Query the data loading backend."""
+
+    def __getstate__(self):
+        return None
+
+    def __setstate__(self, state):
+        pass
+
+    def process_request(self, server_state):
+        backend = (
+            DATA_LOADING_BACKEND_GRAPHBOLT
+            if server_state.use_graphbolt
+            else DATA_LOADING_BACKEND_DGL
+        )
+        return QueryDataLoadingBackendResponse(backend)
+
+
+class QueryDataLoadingBackendResponse(rpc.Response):
+    """Ack the query data loading backend request"""
+
+    def __init__(self, backend):
+        self._backend = backend
+
+    def __getstate__(self):
+        return self._backend
+
+    def __setstate__(self, state):
+        self._backend = state
+
+
 def _copy_graph_to_shared_mem(g, graph_name, graph_format, use_graphbolt):
     if use_graphbolt:
         return g.copy_to_shared_memory(graph_name)
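To make the round trip concrete: a sketch of the client side of this new service, assuming an rpc client that has already connected to the servers. It mirrors the `DistGraph` code further down, including the read of the private `_backend` field:

```python
from dgl.distributed import rpc
from dgl.distributed.constants import DATA_LOADING_BACKEND_GRAPHBOLT
from dgl.distributed.dist_graph import QueryDataLoadingBackendRequest

# Ask the main server (server_id 0) which backend it was launched with.
# This only works once the rpc client is connected.
rpc.send_request(0, QueryDataLoadingBackendRequest())
response = rpc.recv_response()
use_graphbolt = response._backend == DATA_LOADING_BACKEND_GRAPHBOLT
```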
@@ -473,6 +505,7 @@ class DistGraphServer(KVServer):
             kv_store=self,
             local_g=self.client_g,
             partition_book=self.gpb,
+            use_graphbolt=self.use_graphbolt,
         )
         print(
             "start graph service on server {} for part {}".format(
@@ -529,8 +562,6 @@ class DistGraph:
     part_config : str, optional
         The path of partition configuration file generated by
         :py:meth:`dgl.distributed.partition.partition_graph`. It's used in the standalone mode.
-    use_graphbolt : bool, optional
-        Whether to load GraphBolt partition. Default: False.

     Examples
     --------
@@ -564,15 +595,11 @@
     manually setting up servers and trainers. The setup is not fully tested yet.
     """

-    def __init__(
-        self, graph_name, gpb=None, part_config=None, use_graphbolt=False
-    ):
+    def __init__(self, graph_name, gpb=None, part_config=None):
         self.graph_name = graph_name
-        self._use_graphbolt = use_graphbolt
         if os.environ.get("DGL_DIST_MODE", "standalone") == "standalone":
-            assert (
-                use_graphbolt is False
-            ), "GraphBolt is not supported in standalone mode."
+            # "GraphBolt is not supported in standalone mode."
+            self._use_graphbolt = False
             assert (
                 part_config is not None
             ), "When running in the standalone model, the partition config file is required"
@@ -610,12 +637,16 @@
             self._client.map_shared_data(self._gpb)
             rpc.set_num_client(1)
         else:
+            # Query the main server about whether GraphBolt is used.
+            rpc.send_request(0, QueryDataLoadingBackendRequest())
+            self._use_graphbolt = (
+                rpc.recv_response()._backend == DATA_LOADING_BACKEND_GRAPHBOLT
+            )
+
             self._init(gpb)
             # Tell the backup servers to load the graph structure from shared memory.
             for server_id in range(self._client.num_servers):
-                rpc.send_request(
-                    server_id, InitGraphRequest(graph_name, use_graphbolt)
-                )
+                rpc.send_request(server_id, InitGraphRequest(graph_name))
             for server_id in range(self._client.num_servers):
                 rpc.recv_response()
             self._client.barrier()
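The upshot for user code: `use_graphbolt` disappears from the `DistGraph` constructor, and each client inherits whatever backend the servers were launched with. A sketch under the same placeholder assumptions as above:

```python
import dgl

dgl.distributed.initialize(
    "ip_config.txt", data_loading_backend="GraphBolt"
)
# No use_graphbolt argument anymore: the constructor queries the main
# server and adopts the backend the servers were started with.
g = dgl.distributed.DistGraph("my_graph")  # "my_graph" is a placeholder
```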
@@ -1832,3 +1863,8 @@ def edge_split(

 rpc.register_service(INIT_GRAPH, InitGraphRequest, InitGraphResponse)
+rpc.register_service(
+    QUERY_DATA_LOADING_BACKEND,
+    QueryDataLoadingBackendRequest,
+    QueryDataLoadingBackendResponse,
+)
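For readers unfamiliar with the rpc layer: `register_service` binds a unique integer id to a request/response pair so the server can dispatch `process_request`. A hypothetical service (not part of this change; the id and class names are invented for illustration) following the same convention:

```python
from dgl.distributed import rpc

ECHO_SERVICE = 800003  # hypothetical unused service id


class EchoRequest(rpc.Request):
    """Hypothetical request that echoes its payload back."""

    def __init__(self, payload):
        self._payload = payload

    def __getstate__(self):
        return self._payload

    def __setstate__(self, state):
        self._payload = state

    def process_request(self, server_state):
        # Runs on the server; server_state gives access to graph/kvstore.
        return EchoResponse(self._payload)


class EchoResponse(rpc.Response):
    """Hypothetical response carrying the echoed payload."""

    def __init__(self, payload):
        self._payload = payload

    def __getstate__(self):
        return self._payload

    def __setstate__(self, state):
        self._payload = state


rpc.register_service(ECHO_SERVICE, EchoRequest, EchoResponse)
```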
@@ -38,13 +38,16 @@ class ServerState:
         Total number of edges
     partition_book : GraphPartitionBook
         Graph Partition book
+    use_graphbolt : bool
+        Whether to use graphbolt for dataloading.
     """

-    def __init__(self, kv_store, local_g, partition_book):
+    def __init__(self, kv_store, local_g, partition_book, use_graphbolt=False):
         self._kv_store = kv_store
         self._graph = local_g
         self.partition_book = partition_book
         self._roles = {}
+        self._use_graphbolt = use_graphbolt

     @property
     def roles(self):
@@ -69,5 +72,10 @@
     def graph(self, graph):
         self._graph = graph

+    @property
+    def use_graphbolt(self):
+        """Whether to use graphbolt for dataloading."""
+        return self._use_graphbolt
+

 _init_api("dgl.distributed.server_state")
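A quick check of the new flag; the `None` members are placeholders just to exercise the constructor and the read-only property (in production, `DistGraphServer` supplies real objects, as in the hunk above):

```python
from dgl.distributed.server_state import ServerState

state = ServerState(
    kv_store=None, local_g=None, partition_book=None, use_graphbolt=True
)
assert state.use_graphbolt  # exposed through the new read-only property
```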
@@ -141,7 +141,7 @@ def run_client_empty(
     gpb, graph_name, _, _ = load_partition_book(
         "/tmp/dist_graph/{}.json".format(graph_name), part_id
     )
-    g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
+    g = DistGraph(graph_name, gpb=gpb)
     check_dist_graph_empty(g, num_clients, num_nodes, num_edges)
@@ -222,7 +222,7 @@ def run_client(
     gpb, graph_name, _, _ = load_partition_book(
         "/tmp/dist_graph/{}.json".format(graph_name), part_id
     )
-    g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
+    g = DistGraph(graph_name, gpb=gpb)
     check_dist_graph(
         g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt
     )
@@ -322,7 +322,7 @@ def run_client_hierarchy(
     gpb, graph_name, _, _ = load_partition_book(
         "/tmp/dist_graph/{}.json".format(graph_name), part_id
     )
-    g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
+    g = DistGraph(graph_name, gpb=gpb)
     node_mask = F.tensor(node_mask)
     edge_mask = F.tensor(edge_mask)
     nodes = node_split(
@@ -742,7 +742,7 @@ def run_client_hetero(
     gpb, graph_name, _, _ = load_partition_book(
         "/tmp/dist_graph/{}.json".format(graph_name), part_id
     )
-    g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
+    g = DistGraph(graph_name, gpb=gpb)
     check_dist_graph_hetero(
         g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt
     )
@@ -84,9 +84,7 @@ def start_sample_client_shuffle(
         tmpdir / "test_sampling.json", rank
     )
     dgl.distributed.initialize("rpc_ip_config.txt")
-    dist_graph = DistGraph(
-        "test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
-    )
+    dist_graph = DistGraph("test_sampling", gpb=gpb)
     sampled_graph = sample_neighbors(
         dist_graph, [0, 10, 99, 66, 1024, 2008], 3, use_graphbolt=use_graphbolt
     )
@@ -477,9 +475,7 @@ def start_hetero_sample_client(
         tmpdir / "test_sampling.json", rank
     )
     dgl.distributed.initialize("rpc_ip_config.txt")
-    dist_graph = DistGraph(
-        "test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
-    )
+    dist_graph = DistGraph("test_sampling", gpb=gpb)
     assert "feat" in dist_graph.nodes["n1"].data
     assert "feat" not in dist_graph.nodes["n2"].data
     assert "feat" not in dist_graph.nodes["n3"].data
@@ -517,9 +513,7 @@ def start_hetero_etype_sample_client(
         tmpdir / "test_sampling.json", rank
     )
     dgl.distributed.initialize("rpc_ip_config.txt")
-    dist_graph = DistGraph(
-        "test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
-    )
+    dist_graph = DistGraph("test_sampling", gpb=gpb)
     assert "feat" in dist_graph.nodes["n1"].data
     assert "feat" not in dist_graph.nodes["n2"].data
     assert "feat" not in dist_graph.nodes["n3"].data
@@ -876,9 +870,7 @@ def start_bipartite_sample_client(
         tmpdir / "test_sampling.json", rank
     )
     dgl.distributed.initialize("rpc_ip_config.txt")
-    dist_graph = DistGraph(
-        "test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
-    )
+    dist_graph = DistGraph("test_sampling", gpb=gpb)
     assert "feat" in dist_graph.nodes["user"].data
     assert "feat" in dist_graph.nodes["game"].data
     if gpb is None:
@@ -911,9 +903,7 @@ def start_bipartite_etype_sample_client(
         tmpdir / "test_sampling.json", rank
     )
     dgl.distributed.initialize("rpc_ip_config.txt")
-    dist_graph = DistGraph(
-        "test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
-    )
+    dist_graph = DistGraph("test_sampling", gpb=gpb)
     assert "feat" in dist_graph.nodes["user"].data
     assert "feat" in dist_graph.nodes["game"].data
@@ -11,6 +11,8 @@ import pytest
 import torch as th
 from dgl.data import CitationGraphDataset
 from dgl.distributed import (
+    DATA_LOADING_BACKEND_DGL,
+    DATA_LOADING_BACKEND_GRAPHBOLT,
     DistDataLoader,
     DistGraph,
     DistGraphServer,
@@ -104,7 +106,6 @@ def start_dist_dataloader(
         "test_sampling",
         gpb=gpb,
         part_config=part_config,
-        use_graphbolt=use_graphbolt,
     )

     # Create sampler
@@ -443,7 +444,6 @@ def start_node_dataloader(
         "test_sampling",
         gpb=gpb,
         part_config=part_config,
-        use_graphbolt=use_graphbolt,
     )
     assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes)
     assert len(dist_graph.etypes) == len(groundtruth_g.etypes)
@@ -763,9 +763,7 @@ def start_multiple_dataloaders(
     use_graphbolt,
 ):
     dgl.distributed.initialize(ip_config)
-    dist_g = dgl.distributed.DistGraph(
-        graph_name, part_config=part_config, use_graphbolt=use_graphbolt
-    )
+    dist_g = dgl.distributed.DistGraph(graph_name, part_config=part_config)
     if dataloader_type == "node":
         train_ids = th.arange(orig_g.num_nodes())
         batch_size = orig_g.num_nodes() // 100