Unverified Commit df6b3250 authored by Rhett Ying, committed by GitHub

[DistGB] enable DistGraph to load graphbolt partitions (#7048)

parent 942b17ab
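
In short: this commit threads a use_graphbolt flag from DistGraph through the RPC layer so that servers and clients load GraphBolt (FusedCSCSamplingGraph) partitions instead of DGL heterographs. A minimal usage sketch of the resulting flow; the graph name, ip_config path, and partition directory are illustrative placeholders, and servers must be started with the same flag (as the tests below do):

import dgl
from dgl.distributed import DistGraph, dgl_partition_to_graphbolt

# Offline, one time: convert existing DGL partitions to the GraphBolt format.
# The JSON file is the config emitted by dgl.distributed.partition_graph.
dgl_partition_to_graphbolt("/tmp/dist_graph/my_graph.json")

# At runtime, in distributed mode:
dgl.distributed.initialize("ip_config.txt")
g = DistGraph("my_graph", use_graphbolt=True)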
@@ -60,18 +60,21 @@ class InitGraphRequest(rpc.Request):
     with shared memory.
     """

-    def __init__(self, graph_name):
+    def __init__(self, graph_name, use_graphbolt):
         self._graph_name = graph_name
+        self._use_graphbolt = use_graphbolt

     def __getstate__(self):
-        return self._graph_name
+        return self._graph_name, self._use_graphbolt

     def __setstate__(self, state):
-        self._graph_name = state
+        self._graph_name, self._use_graphbolt = state

     def process_request(self, server_state):
         if server_state.graph is None:
-            server_state.graph = _get_graph_from_shared_mem(self._graph_name)
+            server_state.graph = _get_graph_from_shared_mem(
+                self._graph_name, self._use_graphbolt
+            )
         return InitGraphResponse(self._graph_name)
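
The __getstate__/__setstate__ pair matters because DGL's RPC layer pickles requests before shipping them to (backup) servers; extending only one of the two would silently drop the new flag. A toy illustration of the pattern with plain pickle (the class below is a hypothetical stand-in, not DGL code):

import pickle

class InitGraphRequestLike:
    """Hypothetical stand-in mirroring the (graph_name, use_graphbolt) state."""

    def __init__(self, graph_name, use_graphbolt):
        self._graph_name = graph_name
        self._use_graphbolt = use_graphbolt

    def __getstate__(self):
        # Only what is returned here survives the round trip.
        return self._graph_name, self._use_graphbolt

    def __setstate__(self, state):
        self._graph_name, self._use_graphbolt = state

req = pickle.loads(pickle.dumps(InitGraphRequestLike("part_g", True)))
assert req._use_graphbolt is True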
@@ -153,13 +156,15 @@ def _exist_shared_mem_array(graph_name, name):
     return exist_shared_mem_array(_get_edata_path(graph_name, name))


-def _get_graph_from_shared_mem(graph_name):
+def _get_graph_from_shared_mem(graph_name, use_graphbolt):
     """Get the graph from the DistGraph server.

     The DistGraph server puts the graph structure of the local partition in the shared memory.
     The client can access the graph structure and some metadata on nodes and edges directly
     through shared memory to reduce the overhead of data access.
     """
+    if use_graphbolt:
+        return gb.load_from_shared_memory(graph_name)
     g, ntypes, etypes = heterograph_index.create_heterograph_from_shared_memory(
         graph_name
     )
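
The new GraphBolt branch relies on the same contract the docstring describes: the server places the partition in shared memory and clients attach to it by name. A sketch of that round trip, assuming GraphBolt's fused_csc_sampling_graph factory and the copy_to_shared_memory counterpart (only load_from_shared_memory appears in this diff):

import torch
import dgl.graphbolt as gb

# A toy CSC graph: node 0 has in-neighbors 1 and 2, node 1 has in-neighbor 0.
indptr = torch.tensor([0, 2, 3, 3])
indices = torch.tensor([1, 2, 0])
graph = gb.fused_csc_sampling_graph(indptr, indices)

# Server side: pin the structure in shared memory under a well-known name.
shm_graph = graph.copy_to_shared_memory("toy_graph")

# Client side: attach to the same structure without copying it.
client_graph = gb.load_from_shared_memory("toy_graph")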
@@ -524,6 +529,8 @@ class DistGraph:
     part_config : str, optional
         The path of partition configuration file generated by
         :py:meth:`dgl.distributed.partition.partition_graph`. It's used in the standalone mode.
+    use_graphbolt : bool, optional
+        Whether to load GraphBolt partition. Default: False.

     Examples
     --------
@@ -557,9 +564,15 @@ class DistGraph:
     manually setting up servers and trainers. The setup is not fully tested yet.
     """

-    def __init__(self, graph_name, gpb=None, part_config=None):
+    def __init__(
+        self, graph_name, gpb=None, part_config=None, use_graphbolt=False
+    ):
         self.graph_name = graph_name
+        self._use_graphbolt = use_graphbolt
         if os.environ.get("DGL_DIST_MODE", "standalone") == "standalone":
+            assert (
+                use_graphbolt is False
+            ), "GraphBolt is not supported in standalone mode."
             assert (
                 part_config is not None
             ), "When running in the standalone mode, the partition config file is required"
@@ -600,7 +613,9 @@ class DistGraph:
             self._init(gpb)
             # Tell the backup servers to load the graph structure from shared memory.
             for server_id in range(self._client.num_servers):
-                rpc.send_request(server_id, InitGraphRequest(graph_name))
+                rpc.send_request(
+                    server_id, InitGraphRequest(graph_name, use_graphbolt)
+                )
             for server_id in range(self._client.num_servers):
                 rpc.recv_response()
             self._client.barrier()
@@ -625,7 +640,9 @@ class DistGraph:
        assert (
            self._client is not None
        ), "Distributed module is not initialized. Please call dgl.distributed.initialize."
-        self._g = _get_graph_from_shared_mem(self.graph_name)
+        self._g = _get_graph_from_shared_mem(
+            self.graph_name, self._use_graphbolt
+        )
         self._gpb = get_shared_mem_partition_book(self.graph_name)
         if self._gpb is None:
             self._gpb = gpb
@@ -682,10 +699,10 @@ class DistGraph:
             self._edata_store[etype] = data

     def __getstate__(self):
-        return self.graph_name, self._gpb
+        return self.graph_name, self._gpb, self._use_graphbolt

     def __setstate__(self, state):
-        self.graph_name, gpb = state
+        self.graph_name, gpb, self._use_graphbolt = state

         self._init(gpb)
         self._init_ndata_store()
@@ -1230,6 +1247,9 @@ class DistGraph:
         tensor
             The destination node ID array.
         """
+        assert (
+            self._use_graphbolt is False
+        ), "find_edges is not supported in GraphBolt."
         if etype is None:
             assert (
                 len(self.etypes) == 1
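
The tests below exercise this guard through edge_subgraph rather than find_edges directly, because DistGraph.edge_subgraph resolves edge endpoints via find_edges, so this assertion is what callers actually see. A hypothetical minimal reproduction of that failure mode (toy class, not DGL code):

import pytest

class FakeGraphBoltDistGraph:
    """Toy stand-in for a GraphBolt-backed DistGraph (illustrative only)."""

    _use_graphbolt = True

    def find_edges(self, eids):
        assert (
            self._use_graphbolt is False
        ), "find_edges is not supported in GraphBolt."

    def edge_subgraph(self, eids):
        # Endpoint lookup runs before subgraph construction,
        # so the find_edges guard fires first.
        self.find_edges(eids)

with pytest.raises(
    AssertionError, match="find_edges is not supported in GraphBolt."
):
    FakeGraphBoltDistGraph().edge_subgraph([0, 1])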
......
@@ -13,11 +13,13 @@ from multiprocessing import Condition, Manager, Process, Value
 import backend as F

 import dgl
+import dgl.graphbolt as gb
 import numpy as np
 import pytest
 import torch as th
 from dgl.data.utils import load_graphs, save_graphs
 from dgl.distributed import (
+    dgl_partition_to_graphbolt,
     DistEmbedding,
     DistGraph,
     DistGraphServer,
@@ -38,12 +40,33 @@ if os.name != "nt":
     import struct


+def _verify_dist_graph_server_dgl(g):
+    # verify dtype of underlying graph
+    cg = g.client_g
+    for k, dtype in dgl.distributed.dist_graph.RESERVED_FIELD_DTYPE.items():
+        if k in cg.ndata:
+            assert (
+                F.dtype(cg.ndata[k]) == dtype
+            ), "Data type of {} in ndata should be {}.".format(k, dtype)
+        if k in cg.edata:
+            assert (
+                F.dtype(cg.edata[k]) == dtype
+            ), "Data type of {} in edata should be {}.".format(k, dtype)
+
+
+def _verify_dist_graph_server_graphbolt(g):
+    graph = g.client_g
+    assert isinstance(graph, gb.FusedCSCSamplingGraph)
+    # [Rui][TODO] verify dtype of underlying graph.
+
+
 def run_server(
     graph_name,
     server_id,
     server_count,
     num_clients,
     shared_mem,
+    use_graphbolt=False,
 ):
     g = DistGraphServer(
         server_id,
@@ -53,19 +76,15 @@ def run_server(
         "/tmp/dist_graph/{}.json".format(graph_name),
         disable_shared_mem=not shared_mem,
         graph_format=["csc", "coo"],
+        use_graphbolt=use_graphbolt,
     )
-    print("start server", server_id)
-    # verify dtype of underlying graph
-    cg = g.client_g
-    for k, dtype in dgl.distributed.dist_graph.RESERVED_FIELD_DTYPE.items():
-        if k in cg.ndata:
-            assert (
-                F.dtype(cg.ndata[k]) == dtype
-            ), "Data type of {} in ndata should be {}.".format(k, dtype)
-        if k in cg.edata:
-            assert (
-                F.dtype(cg.edata[k]) == dtype
-            ), "Data type of {} in edata should be {}.".format(k, dtype)
+    print(f"Starting server[{server_id}] with use_graphbolt={use_graphbolt}")
+    _verify = (
+        _verify_dist_graph_server_graphbolt
+        if use_graphbolt
+        else _verify_dist_graph_server_dgl
+    )
+    _verify(g)
     g.start()
@@ -110,18 +129,26 @@ def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
 def run_client_empty(
-    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
+    graph_name,
+    part_id,
+    server_count,
+    num_clients,
+    num_nodes,
+    num_edges,
+    use_graphbolt=False,
 ):
     os.environ["DGL_NUM_SERVER"] = str(server_count)
     dgl.distributed.initialize("kv_ip_config.txt")
     gpb, graph_name, _, _ = load_partition_book(
         "/tmp/dist_graph/{}.json".format(graph_name), part_id
     )
-    g = DistGraph(graph_name, gpb=gpb)
+    g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
     check_dist_graph_empty(g, num_clients, num_nodes, num_edges)


-def check_server_client_empty(shared_mem, num_servers, num_clients):
+def check_server_client_empty(
+    shared_mem, num_servers, num_clients, use_graphbolt=False
+):
     prepare_dist(num_servers)
     g = create_random_graph(10000)
@@ -129,6 +156,9 @@ def check_server_client_empty(shared_mem, num_servers, num_clients):
     num_parts = 1
     graph_name = "dist_graph_test_1"
     partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")
+    if use_graphbolt:
+        part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
+        dgl_partition_to_graphbolt(part_config)

     # let's just test on one partition for now.
     # We cannot run multiple servers and clients on the same machine.
@@ -137,7 +167,14 @@ def check_server_client_empty(shared_mem, num_servers, num_clients):
     for serv_id in range(num_servers):
         p = ctx.Process(
             target=run_server,
-            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
+            args=(
+                graph_name,
+                serv_id,
+                num_servers,
+                num_clients,
+                shared_mem,
+                use_graphbolt,
+            ),
         )
         serv_ps.append(p)
         p.start()
@@ -154,6 +191,7 @@ def check_server_client_empty(shared_mem, num_servers, num_clients):
                 num_clients,
                 g.num_nodes(),
                 g.num_edges(),
+                use_graphbolt,
             ),
         )
         p.start()
@@ -178,6 +216,7 @@ def run_client(
     num_nodes,
     num_edges,
     group_id,
+    use_graphbolt=False,
 ):
     os.environ["DGL_NUM_SERVER"] = str(server_count)
     os.environ["DGL_GROUP_ID"] = str(group_id)
@@ -185,8 +224,10 @@ def run_client(
     gpb, graph_name, _, _ = load_partition_book(
         "/tmp/dist_graph/{}.json".format(graph_name), part_id
     )
-    g = DistGraph(graph_name, gpb=gpb)
-    check_dist_graph(g, num_clients, num_nodes, num_edges)
+    g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
+    check_dist_graph(
+        g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt
+    )


 def run_emb_client(
@@ -270,14 +311,20 @@ def check_dist_optim_store(rank, num_nodes, optimizer_states, save):
 def run_client_hierarchy(
-    graph_name, part_id, server_count, node_mask, edge_mask, return_dict
+    graph_name,
+    part_id,
+    server_count,
+    node_mask,
+    edge_mask,
+    return_dict,
+    use_graphbolt=False,
 ):
     os.environ["DGL_NUM_SERVER"] = str(server_count)
     dgl.distributed.initialize("kv_ip_config.txt")
     gpb, graph_name, _, _ = load_partition_book(
         "/tmp/dist_graph/{}.json".format(graph_name), part_id
     )
-    g = DistGraph(graph_name, gpb=gpb)
+    g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
     node_mask = F.tensor(node_mask)
     edge_mask = F.tensor(edge_mask)
     nodes = node_split(
@@ -355,7 +402,7 @@ def check_dist_emb(g, num_clients, num_nodes, num_edges):
         sys.exit(-1)


-def check_dist_graph(g, num_clients, num_nodes, num_edges):
+def check_dist_graph(g, num_clients, num_nodes, num_edges, use_graphbolt=False):
     # Test API
     assert g.num_nodes() == num_nodes
     assert g.num_edges() == num_edges
@@ -373,6 +420,12 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges):
     assert np.all(F.asnumpy(feats == eids))

     # Test edge_subgraph
-    sg = g.edge_subgraph(eids)
-    assert sg.num_edges() == len(eids)
-    assert F.array_equal(sg.edata[dgl.EID], eids)
+    if use_graphbolt:
+        with pytest.raises(
+            AssertionError, match="find_edges is not supported in GraphBolt."
+        ):
+            g.edge_subgraph(eids)
+    else:
+        sg = g.edge_subgraph(eids)
+        assert sg.num_edges() == len(eids)
+        assert F.array_equal(sg.edata[dgl.EID], eids)
@@ -522,7 +575,9 @@ def check_dist_emb_server_client(
     print("clients have terminated")


-def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
+def check_server_client(
+    shared_mem, num_servers, num_clients, num_groups=1, use_graphbolt=False
+):
     prepare_dist(num_servers)
     g = create_random_graph(10000)
@@ -532,6 +587,9 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
     g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
     g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
     partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")
+    if use_graphbolt:
+        part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
+        dgl_partition_to_graphbolt(part_config)

     # let's just test on one partition for now.
     # We cannot run multiple servers and clients on the same machine.
@@ -546,6 +604,7 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
                 num_servers,
                 num_clients,
                 shared_mem,
+                use_graphbolt,
             ),
         )
         serv_ps.append(p)
@@ -566,6 +625,7 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
                     g.num_nodes(),
                     g.num_edges(),
                     group_id,
+                    use_graphbolt,
                 ),
             )
             p.start()
@@ -582,7 +642,12 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
     print("clients have terminated")


-def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
+def check_server_client_hierarchy(
+    shared_mem, num_servers, num_clients, use_graphbolt=False
+):
+    if num_clients == 1:
+        # skip this test if there is only one client.
+        return
     prepare_dist(num_servers)
     g = create_random_graph(10000)
@@ -598,6 +663,9 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
         "/tmp/dist_graph",
         num_trainers_per_machine=num_clients,
     )
+    if use_graphbolt:
+        part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
+        dgl_partition_to_graphbolt(part_config)

     # let's just test on one partition for now.
     # We cannot run multiple servers and clients on the same machine.
@@ -606,7 +674,14 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
     for serv_id in range(num_servers):
         p = ctx.Process(
             target=run_server,
-            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
+            args=(
+                graph_name,
+                serv_id,
+                num_servers,
+                num_clients,
+                shared_mem,
+                use_graphbolt,
+            ),
         )
         serv_ps.append(p)
         p.start()
@@ -633,6 +708,7 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
                 node_mask,
                 edge_mask,
                 return_dict,
+                use_graphbolt,
             ),
         )
         p.start()
@@ -658,15 +734,23 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
 def run_client_hetero(
-    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
+    graph_name,
+    part_id,
+    server_count,
+    num_clients,
+    num_nodes,
+    num_edges,
+    use_graphbolt=False,
 ):
     os.environ["DGL_NUM_SERVER"] = str(server_count)
     dgl.distributed.initialize("kv_ip_config.txt")
     gpb, graph_name, _, _ = load_partition_book(
         "/tmp/dist_graph/{}.json".format(graph_name), part_id
     )
-    g = DistGraph(graph_name, gpb=gpb)
-    check_dist_graph_hetero(g, num_clients, num_nodes, num_edges)
+    g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
+    check_dist_graph_hetero(
+        g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt
+    )


 def create_random_hetero():
@@ -701,7 +785,9 @@ def create_random_hetero():
     return g


-def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
+def check_dist_graph_hetero(
+    g, num_clients, num_nodes, num_edges, use_graphbolt=False
+):
     # Test API
     for ntype in num_nodes:
         assert ntype in g.ntypes
@@ -754,6 +840,12 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
         assert expect_except

     # Test edge_subgraph
-    sg = g.edge_subgraph({"r1": eids})
-    assert sg.num_edges() == len(eids)
-    assert F.array_equal(sg.edata[dgl.EID], eids)
+    if use_graphbolt:
+        with pytest.raises(
+            AssertionError, match="find_edges is not supported in GraphBolt."
+        ):
+            g.edge_subgraph({"r1": eids})
+    else:
+        sg = g.edge_subgraph({"r1": eids})
+        assert sg.num_edges() == len(eids)
+        assert F.array_equal(sg.edata[dgl.EID], eids)
@@ -827,7 +919,9 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
     print("end")


-def check_server_client_hetero(shared_mem, num_servers, num_clients):
+def check_server_client_hetero(
+    shared_mem, num_servers, num_clients, use_graphbolt=False
+):
     prepare_dist(num_servers)
     g = create_random_hetero()
@@ -835,6 +929,9 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
     num_parts = 1
     graph_name = "dist_graph_test_3"
     partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")
+    if use_graphbolt:
+        part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
+        dgl_partition_to_graphbolt(part_config)

     # let's just test on one partition for now.
     # We cannot run multiple servers and clients on the same machine.
@@ -843,7 +940,14 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
     for serv_id in range(num_servers):
         p = ctx.Process(
             target=run_server,
-            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
+            args=(
+                graph_name,
+                serv_id,
+                num_servers,
+                num_clients,
+                shared_mem,
+                use_graphbolt,
+            ),
        )
        serv_ps.append(p)
        p.start()
@@ -862,6 +966,7 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
                 num_clients,
                 num_nodes,
                 num_edges,
+                use_graphbolt,
             ),
         )
         p.start()
@@ -886,21 +991,23 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
 @unittest.skipIf(
     dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
 )
-def test_server_client():
+@pytest.mark.parametrize("shared_mem", [True])
+@pytest.mark.parametrize("num_servers", [1])
+@pytest.mark.parametrize("num_clients", [1, 4])
+@pytest.mark.parametrize("use_graphbolt", [True, False])
+def test_server_client(shared_mem, num_servers, num_clients, use_graphbolt):
     reset_envs()
     os.environ["DGL_DIST_MODE"] = "distributed"
-    check_server_client_hierarchy(False, 1, 4)
-    check_server_client_empty(True, 1, 1)
-    check_server_client_hetero(True, 1, 1)
-    check_server_client_hetero(False, 1, 1)
-    check_server_client(True, 1, 1)
-    check_server_client(False, 1, 1)
     # [TODO][Rhett] Tests for multiple groups may fail sometimes and
     # root cause is unknown. Let's disable them for now.
     # check_server_client(True, 2, 2)
     # check_server_client(True, 1, 1, 2)
     # check_server_client(False, 1, 1, 2)
     # check_server_client(True, 2, 2, 2)
+    # [Rui]
+    # 1. `disable_shared_mem=False` is not supported yet. Skip it.
+    # 2. `num_servers` > 1 does not work on single machine. Skip it.
+    for func in [
+        check_server_client,
+        check_server_client_hetero,
+        check_server_client_empty,
+        check_server_client_hierarchy,
+    ]:
+        func(shared_mem, num_servers, num_clients, use_graphbolt=use_graphbolt)


 @unittest.skip(reason="Skip due to glitch in CI")
......
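
Putting the pieces together, the test flow above is: partition, convert with dgl_partition_to_graphbolt, start servers with the flag, attach clients with the flag. A condensed sketch of the server side; the paths and graph name are placeholders, and the positional parameter order is taken from the run_server calls above (treat it as an assumption):

from dgl.distributed import DistGraphServer

serv = DistGraphServer(
    0,                                # server_id
    "ip_config.txt",                  # ip_config path (placeholder)
    1,                                # number of servers
    1,                                # number of clients
    "/tmp/dist_graph/my_graph.json",  # part_config (placeholder)
    disable_shared_mem=False,
    graph_format=["csc", "coo"],
    use_graphbolt=True,
)
serv.start()

# Each client process then attaches, mirroring run_client above:
#   dgl.distributed.initialize("ip_config.txt")
#   g = DistGraph("my_graph", gpb=gpb, use_graphbolt=True)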