Unverified Commit b0080d5b authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] add graphbolt flag into top level API (#7122)

parent 7a10bcb6
...@@ -19,3 +19,4 @@ from .rpc import * ...@@ -19,3 +19,4 @@ from .rpc import *
from .rpc_client import connect_to_server from .rpc_client import connect_to_server
from .rpc_server import start_server from .rpc_server import start_server
from .server_state import ServerState from .server_state import ServerState
from .constants import *
...@@ -7,3 +7,6 @@ SERVER_EXIT = "server_exit" ...@@ -7,3 +7,6 @@ SERVER_EXIT = "server_exit"
DEFAULT_NTYPE = "_N" DEFAULT_NTYPE = "_N"
DEFAULT_ETYPE = (DEFAULT_NTYPE, "_E", DEFAULT_NTYPE) DEFAULT_ETYPE = (DEFAULT_NTYPE, "_E", DEFAULT_NTYPE)
DATA_LOADING_BACKEND_DGL = "DGL"
DATA_LOADING_BACKEND_GRAPHBOLT = "GraphBolt"
...@@ -14,7 +14,11 @@ from enum import Enum ...@@ -14,7 +14,11 @@ from enum import Enum
from .. import utils from .. import utils
from ..base import dgl_warning, DGLError from ..base import dgl_warning, DGLError
from . import rpc from . import rpc
from .constants import MAX_QUEUE_SIZE from .constants import (
DATA_LOADING_BACKEND_DGL,
DATA_LOADING_BACKEND_GRAPHBOLT,
MAX_QUEUE_SIZE,
)
from .kvstore import close_kvstore, init_kvstore from .kvstore import close_kvstore, init_kvstore
from .role import init_role from .role import init_role
from .rpc_client import connect_to_server from .rpc_client import connect_to_server
...@@ -210,6 +214,7 @@ def initialize( ...@@ -210,6 +214,7 @@ def initialize(
max_queue_size=MAX_QUEUE_SIZE, max_queue_size=MAX_QUEUE_SIZE,
net_type=None, net_type=None,
num_worker_threads=1, num_worker_threads=1,
data_loading_backend=DATA_LOADING_BACKEND_DGL,
): ):
"""Initialize DGL's distributed module """Initialize DGL's distributed module
...@@ -231,6 +236,8 @@ def initialize( ...@@ -231,6 +236,8 @@ def initialize(
[Deprecated] Networking type, can be 'socket' only. [Deprecated] Networking type, can be 'socket' only.
num_worker_threads: int num_worker_threads: int
The number of OMP threads in each sampler process. The number of OMP threads in each sampler process.
data_loading_backend: str, optional
The backend for data loading. Can be 'DGL' or 'GraphBolt'.
Note Note
---- ----
...@@ -270,6 +277,8 @@ def initialize( ...@@ -270,6 +277,8 @@ def initialize(
int(os.environ.get("DGL_NUM_CLIENT")), int(os.environ.get("DGL_NUM_CLIENT")),
os.environ.get("DGL_CONF_PATH"), os.environ.get("DGL_CONF_PATH"),
graph_format=formats, graph_format=formats,
use_graphbolt=data_loading_backend
== DATA_LOADING_BACKEND_GRAPHBOLT,
) )
serv.start() serv.start()
sys.exit() sys.exit()
......
...@@ -18,6 +18,7 @@ from ..heterograph import DGLGraph ...@@ -18,6 +18,7 @@ from ..heterograph import DGLGraph
from ..ndarray import exist_shared_mem_array from ..ndarray import exist_shared_mem_array
from ..transforms import compact_graphs from ..transforms import compact_graphs
from . import graph_services, role, rpc from . import graph_services, role, rpc
from .constants import DATA_LOADING_BACKEND_DGL, DATA_LOADING_BACKEND_GRAPHBOLT
from .dist_tensor import DistTensor from .dist_tensor import DistTensor
from .graph_partition_book import ( from .graph_partition_book import (
_etype_str_to_tuple, _etype_str_to_tuple,
...@@ -50,6 +51,7 @@ from .shared_mem_utils import ( ...@@ -50,6 +51,7 @@ from .shared_mem_utils import (
) )
INIT_GRAPH = 800001 INIT_GRAPH = 800001
QUERY_DATA_LOADING_BACKEND = 800002
class InitGraphRequest(rpc.Request): class InitGraphRequest(rpc.Request):
...@@ -60,20 +62,19 @@ class InitGraphRequest(rpc.Request): ...@@ -60,20 +62,19 @@ class InitGraphRequest(rpc.Request):
with shared memory. with shared memory.
""" """
def __init__(self, graph_name, use_graphbolt): def __init__(self, graph_name):
self._graph_name = graph_name self._graph_name = graph_name
self._use_graphbolt = use_graphbolt
def __getstate__(self): def __getstate__(self):
return self._graph_name, self._use_graphbolt return self._graph_name
def __setstate__(self, state): def __setstate__(self, state):
self._graph_name, self._use_graphbolt = state self._graph_name = state
def process_request(self, server_state): def process_request(self, server_state):
if server_state.graph is None: if server_state.graph is None:
server_state.graph = _get_graph_from_shared_mem( server_state.graph = _get_graph_from_shared_mem(
self._graph_name, self._use_graphbolt self._graph_name, server_state.use_graphbolt
) )
return InitGraphResponse(self._graph_name) return InitGraphResponse(self._graph_name)
...@@ -91,6 +92,37 @@ class InitGraphResponse(rpc.Response): ...@@ -91,6 +92,37 @@ class InitGraphResponse(rpc.Response):
self._graph_name = state self._graph_name = state
class QueryDataLoadingBackendRequest(rpc.Request):
"""Query the data loading backend."""
def __getstate__(self):
return None
def __setstate__(self, state):
pass
def process_request(self, server_state):
backend = (
DATA_LOADING_BACKEND_GRAPHBOLT
if server_state.use_graphbolt
else DATA_LOADING_BACKEND_DGL
)
return QueryDataLoadingBackendResponse(backend)
class QueryDataLoadingBackendResponse(rpc.Response):
"""Ack the query data loading backend request"""
def __init__(self, backend):
self._backend = backend
def __getstate__(self):
return self._backend
def __setstate__(self, state):
self._backend = state
def _copy_graph_to_shared_mem(g, graph_name, graph_format, use_graphbolt): def _copy_graph_to_shared_mem(g, graph_name, graph_format, use_graphbolt):
if use_graphbolt: if use_graphbolt:
return g.copy_to_shared_memory(graph_name) return g.copy_to_shared_memory(graph_name)
...@@ -473,6 +505,7 @@ class DistGraphServer(KVServer): ...@@ -473,6 +505,7 @@ class DistGraphServer(KVServer):
kv_store=self, kv_store=self,
local_g=self.client_g, local_g=self.client_g,
partition_book=self.gpb, partition_book=self.gpb,
use_graphbolt=self.use_graphbolt,
) )
print( print(
"start graph service on server {} for part {}".format( "start graph service on server {} for part {}".format(
...@@ -529,8 +562,6 @@ class DistGraph: ...@@ -529,8 +562,6 @@ class DistGraph:
part_config : str, optional part_config : str, optional
The path of partition configuration file generated by The path of partition configuration file generated by
:py:meth:`dgl.distributed.partition.partition_graph`. It's used in the standalone mode. :py:meth:`dgl.distributed.partition.partition_graph`. It's used in the standalone mode.
use_graphbolt : bool, optional
Whether to load GraphBolt partition. Default: False.
Examples Examples
-------- --------
...@@ -564,15 +595,11 @@ class DistGraph: ...@@ -564,15 +595,11 @@ class DistGraph:
manually setting up servers and trainers. The setup is not fully tested yet. manually setting up servers and trainers. The setup is not fully tested yet.
""" """
def __init__( def __init__(self, graph_name, gpb=None, part_config=None):
self, graph_name, gpb=None, part_config=None, use_graphbolt=False
):
self.graph_name = graph_name self.graph_name = graph_name
self._use_graphbolt = use_graphbolt
if os.environ.get("DGL_DIST_MODE", "standalone") == "standalone": if os.environ.get("DGL_DIST_MODE", "standalone") == "standalone":
assert ( # "GraphBolt is not supported in standalone mode."
use_graphbolt is False self._use_graphbolt = False
), "GraphBolt is not supported in standalone mode."
assert ( assert (
part_config is not None part_config is not None
), "When running in the standalone model, the partition config file is required" ), "When running in the standalone model, the partition config file is required"
...@@ -610,12 +637,16 @@ class DistGraph: ...@@ -610,12 +637,16 @@ class DistGraph:
self._client.map_shared_data(self._gpb) self._client.map_shared_data(self._gpb)
rpc.set_num_client(1) rpc.set_num_client(1)
else: else:
# Query the main server about whether GraphBolt is used.
rpc.send_request(0, QueryDataLoadingBackendRequest())
self._use_graphbolt = (
rpc.recv_response()._backend == DATA_LOADING_BACKEND_GRAPHBOLT
)
self._init(gpb) self._init(gpb)
# Tell the backup servers to load the graph structure from shared memory. # Tell the backup servers to load the graph structure from shared memory.
for server_id in range(self._client.num_servers): for server_id in range(self._client.num_servers):
rpc.send_request( rpc.send_request(server_id, InitGraphRequest(graph_name))
server_id, InitGraphRequest(graph_name, use_graphbolt)
)
for server_id in range(self._client.num_servers): for server_id in range(self._client.num_servers):
rpc.recv_response() rpc.recv_response()
self._client.barrier() self._client.barrier()
...@@ -1832,3 +1863,8 @@ def edge_split( ...@@ -1832,3 +1863,8 @@ def edge_split(
rpc.register_service(INIT_GRAPH, InitGraphRequest, InitGraphResponse) rpc.register_service(INIT_GRAPH, InitGraphRequest, InitGraphResponse)
rpc.register_service(
QUERY_DATA_LOADING_BACKEND,
QueryDataLoadingBackendRequest,
QueryDataLoadingBackendResponse,
)
...@@ -38,13 +38,16 @@ class ServerState: ...@@ -38,13 +38,16 @@ class ServerState:
Total number of edges Total number of edges
partition_book : GraphPartitionBook partition_book : GraphPartitionBook
Graph Partition book Graph Partition book
use_graphbolt : bool
Whether to use graphbolt for dataloading.
""" """
def __init__(self, kv_store, local_g, partition_book): def __init__(self, kv_store, local_g, partition_book, use_graphbolt=False):
self._kv_store = kv_store self._kv_store = kv_store
self._graph = local_g self._graph = local_g
self.partition_book = partition_book self.partition_book = partition_book
self._roles = {} self._roles = {}
self._use_graphbolt = use_graphbolt
@property @property
def roles(self): def roles(self):
...@@ -69,5 +72,10 @@ class ServerState: ...@@ -69,5 +72,10 @@ class ServerState:
def graph(self, graph): def graph(self, graph):
self._graph = graph self._graph = graph
@property
def use_graphbolt(self):
"""Whether to use graphbolt for dataloading."""
return self._use_graphbolt
_init_api("dgl.distributed.server_state") _init_api("dgl.distributed.server_state")
...@@ -141,7 +141,7 @@ def run_client_empty( ...@@ -141,7 +141,7 @@ def run_client_empty(
gpb, graph_name, _, _ = load_partition_book( gpb, graph_name, _, _ = load_partition_book(
"/tmp/dist_graph/{}.json".format(graph_name), part_id "/tmp/dist_graph/{}.json".format(graph_name), part_id
) )
g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt) g = DistGraph(graph_name, gpb=gpb)
check_dist_graph_empty(g, num_clients, num_nodes, num_edges) check_dist_graph_empty(g, num_clients, num_nodes, num_edges)
...@@ -222,7 +222,7 @@ def run_client( ...@@ -222,7 +222,7 @@ def run_client(
gpb, graph_name, _, _ = load_partition_book( gpb, graph_name, _, _ = load_partition_book(
"/tmp/dist_graph/{}.json".format(graph_name), part_id "/tmp/dist_graph/{}.json".format(graph_name), part_id
) )
g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt) g = DistGraph(graph_name, gpb=gpb)
check_dist_graph( check_dist_graph(
g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt
) )
...@@ -322,7 +322,7 @@ def run_client_hierarchy( ...@@ -322,7 +322,7 @@ def run_client_hierarchy(
gpb, graph_name, _, _ = load_partition_book( gpb, graph_name, _, _ = load_partition_book(
"/tmp/dist_graph/{}.json".format(graph_name), part_id "/tmp/dist_graph/{}.json".format(graph_name), part_id
) )
g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt) g = DistGraph(graph_name, gpb=gpb)
node_mask = F.tensor(node_mask) node_mask = F.tensor(node_mask)
edge_mask = F.tensor(edge_mask) edge_mask = F.tensor(edge_mask)
nodes = node_split( nodes = node_split(
...@@ -742,7 +742,7 @@ def run_client_hetero( ...@@ -742,7 +742,7 @@ def run_client_hetero(
gpb, graph_name, _, _ = load_partition_book( gpb, graph_name, _, _ = load_partition_book(
"/tmp/dist_graph/{}.json".format(graph_name), part_id "/tmp/dist_graph/{}.json".format(graph_name), part_id
) )
g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt) g = DistGraph(graph_name, gpb=gpb)
check_dist_graph_hetero( check_dist_graph_hetero(
g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt
) )
......
...@@ -84,9 +84,7 @@ def start_sample_client_shuffle( ...@@ -84,9 +84,7 @@ def start_sample_client_shuffle(
tmpdir / "test_sampling.json", rank tmpdir / "test_sampling.json", rank
) )
dgl.distributed.initialize("rpc_ip_config.txt") dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph( dist_graph = DistGraph("test_sampling", gpb=gpb)
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
sampled_graph = sample_neighbors( sampled_graph = sample_neighbors(
dist_graph, [0, 10, 99, 66, 1024, 2008], 3, use_graphbolt=use_graphbolt dist_graph, [0, 10, 99, 66, 1024, 2008], 3, use_graphbolt=use_graphbolt
) )
...@@ -477,9 +475,7 @@ def start_hetero_sample_client( ...@@ -477,9 +475,7 @@ def start_hetero_sample_client(
tmpdir / "test_sampling.json", rank tmpdir / "test_sampling.json", rank
) )
dgl.distributed.initialize("rpc_ip_config.txt") dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph( dist_graph = DistGraph("test_sampling", gpb=gpb)
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
assert "feat" in dist_graph.nodes["n1"].data assert "feat" in dist_graph.nodes["n1"].data
assert "feat" not in dist_graph.nodes["n2"].data assert "feat" not in dist_graph.nodes["n2"].data
assert "feat" not in dist_graph.nodes["n3"].data assert "feat" not in dist_graph.nodes["n3"].data
...@@ -517,9 +513,7 @@ def start_hetero_etype_sample_client( ...@@ -517,9 +513,7 @@ def start_hetero_etype_sample_client(
tmpdir / "test_sampling.json", rank tmpdir / "test_sampling.json", rank
) )
dgl.distributed.initialize("rpc_ip_config.txt") dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph( dist_graph = DistGraph("test_sampling", gpb=gpb)
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
assert "feat" in dist_graph.nodes["n1"].data assert "feat" in dist_graph.nodes["n1"].data
assert "feat" not in dist_graph.nodes["n2"].data assert "feat" not in dist_graph.nodes["n2"].data
assert "feat" not in dist_graph.nodes["n3"].data assert "feat" not in dist_graph.nodes["n3"].data
...@@ -876,9 +870,7 @@ def start_bipartite_sample_client( ...@@ -876,9 +870,7 @@ def start_bipartite_sample_client(
tmpdir / "test_sampling.json", rank tmpdir / "test_sampling.json", rank
) )
dgl.distributed.initialize("rpc_ip_config.txt") dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph( dist_graph = DistGraph("test_sampling", gpb=gpb)
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
assert "feat" in dist_graph.nodes["user"].data assert "feat" in dist_graph.nodes["user"].data
assert "feat" in dist_graph.nodes["game"].data assert "feat" in dist_graph.nodes["game"].data
if gpb is None: if gpb is None:
...@@ -911,9 +903,7 @@ def start_bipartite_etype_sample_client( ...@@ -911,9 +903,7 @@ def start_bipartite_etype_sample_client(
tmpdir / "test_sampling.json", rank tmpdir / "test_sampling.json", rank
) )
dgl.distributed.initialize("rpc_ip_config.txt") dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph( dist_graph = DistGraph("test_sampling", gpb=gpb)
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
assert "feat" in dist_graph.nodes["user"].data assert "feat" in dist_graph.nodes["user"].data
assert "feat" in dist_graph.nodes["game"].data assert "feat" in dist_graph.nodes["game"].data
......
...@@ -11,6 +11,8 @@ import pytest ...@@ -11,6 +11,8 @@ import pytest
import torch as th import torch as th
from dgl.data import CitationGraphDataset from dgl.data import CitationGraphDataset
from dgl.distributed import ( from dgl.distributed import (
DATA_LOADING_BACKEND_DGL,
DATA_LOADING_BACKEND_GRAPHBOLT,
DistDataLoader, DistDataLoader,
DistGraph, DistGraph,
DistGraphServer, DistGraphServer,
...@@ -104,7 +106,6 @@ def start_dist_dataloader( ...@@ -104,7 +106,6 @@ def start_dist_dataloader(
"test_sampling", "test_sampling",
gpb=gpb, gpb=gpb,
part_config=part_config, part_config=part_config,
use_graphbolt=use_graphbolt,
) )
# Create sampler # Create sampler
...@@ -443,7 +444,6 @@ def start_node_dataloader( ...@@ -443,7 +444,6 @@ def start_node_dataloader(
"test_sampling", "test_sampling",
gpb=gpb, gpb=gpb,
part_config=part_config, part_config=part_config,
use_graphbolt=use_graphbolt,
) )
assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes) assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes)
assert len(dist_graph.etypes) == len(groundtruth_g.etypes) assert len(dist_graph.etypes) == len(groundtruth_g.etypes)
...@@ -763,9 +763,7 @@ def start_multiple_dataloaders( ...@@ -763,9 +763,7 @@ def start_multiple_dataloaders(
use_graphbolt, use_graphbolt,
): ):
dgl.distributed.initialize(ip_config) dgl.distributed.initialize(ip_config)
dist_g = dgl.distributed.DistGraph( dist_g = dgl.distributed.DistGraph(graph_name, part_config=part_config)
graph_name, part_config=part_config, use_graphbolt=use_graphbolt
)
if dataloader_type == "node": if dataloader_type == "node":
train_ids = th.arange(orig_g.num_nodes()) train_ids = th.arange(orig_g.num_nodes())
batch_size = orig_g.num_nodes() // 100 batch_size = orig_g.num_nodes() // 100
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment