Unverified commit df6b3250, authored by Rhett Ying and committed by GitHub
Browse files

[DistGB] enable DistGraph to load graphbolt partitions (#7048)

parent 942b17ab
...@@ -60,18 +60,21 @@ class InitGraphRequest(rpc.Request): ...@@ -60,18 +60,21 @@ class InitGraphRequest(rpc.Request):
with shared memory. with shared memory.
""" """
def __init__(self, graph_name): def __init__(self, graph_name, use_graphbolt):
self._graph_name = graph_name self._graph_name = graph_name
self._use_graphbolt = use_graphbolt
def __getstate__(self): def __getstate__(self):
return self._graph_name return self._graph_name, self._use_graphbolt
def __setstate__(self, state): def __setstate__(self, state):
self._graph_name = state self._graph_name, self._use_graphbolt = state
def process_request(self, server_state): def process_request(self, server_state):
if server_state.graph is None: if server_state.graph is None:
server_state.graph = _get_graph_from_shared_mem(self._graph_name) server_state.graph = _get_graph_from_shared_mem(
self._graph_name, self._use_graphbolt
)
return InitGraphResponse(self._graph_name) return InitGraphResponse(self._graph_name)
...@@ -153,13 +156,15 @@ def _exist_shared_mem_array(graph_name, name): ...@@ -153,13 +156,15 @@ def _exist_shared_mem_array(graph_name, name):
return exist_shared_mem_array(_get_edata_path(graph_name, name)) return exist_shared_mem_array(_get_edata_path(graph_name, name))
def _get_graph_from_shared_mem(graph_name): def _get_graph_from_shared_mem(graph_name, use_graphbolt):
"""Get the graph from the DistGraph server. """Get the graph from the DistGraph server.
The DistGraph server puts the graph structure of the local partition in the shared memory. The DistGraph server puts the graph structure of the local partition in the shared memory.
The client can access the graph structure and some metadata on nodes and edges directly The client can access the graph structure and some metadata on nodes and edges directly
through shared memory to reduce the overhead of data access. through shared memory to reduce the overhead of data access.
""" """
if use_graphbolt:
return gb.load_from_shared_memory(graph_name)
g, ntypes, etypes = heterograph_index.create_heterograph_from_shared_memory( g, ntypes, etypes = heterograph_index.create_heterograph_from_shared_memory(
graph_name graph_name
) )
...@@ -524,6 +529,8 @@ class DistGraph: ...@@ -524,6 +529,8 @@ class DistGraph:
part_config : str, optional part_config : str, optional
The path of partition configuration file generated by The path of partition configuration file generated by
:py:meth:`dgl.distributed.partition.partition_graph`. It's used in the standalone mode. :py:meth:`dgl.distributed.partition.partition_graph`. It's used in the standalone mode.
use_graphbolt : bool, optional
Whether to load GraphBolt partition. Default: False.
Examples Examples
-------- --------
...@@ -557,9 +564,15 @@ class DistGraph: ...@@ -557,9 +564,15 @@ class DistGraph:
manually setting up servers and trainers. The setup is not fully tested yet. manually setting up servers and trainers. The setup is not fully tested yet.
""" """
def __init__(self, graph_name, gpb=None, part_config=None): def __init__(
self, graph_name, gpb=None, part_config=None, use_graphbolt=False
):
self.graph_name = graph_name self.graph_name = graph_name
self._use_graphbolt = use_graphbolt
if os.environ.get("DGL_DIST_MODE", "standalone") == "standalone": if os.environ.get("DGL_DIST_MODE", "standalone") == "standalone":
assert (
use_graphbolt is False
), "GraphBolt is not supported in standalone mode."
assert ( assert (
part_config is not None part_config is not None
), "When running in the standalone model, the partition config file is required" ), "When running in the standalone model, the partition config file is required"
...@@ -600,7 +613,9 @@ class DistGraph: ...@@ -600,7 +613,9 @@ class DistGraph:
self._init(gpb) self._init(gpb)
# Tell the backup servers to load the graph structure from shared memory. # Tell the backup servers to load the graph structure from shared memory.
for server_id in range(self._client.num_servers): for server_id in range(self._client.num_servers):
rpc.send_request(server_id, InitGraphRequest(graph_name)) rpc.send_request(
server_id, InitGraphRequest(graph_name, use_graphbolt)
)
for server_id in range(self._client.num_servers): for server_id in range(self._client.num_servers):
rpc.recv_response() rpc.recv_response()
self._client.barrier() self._client.barrier()
...@@ -625,7 +640,9 @@ class DistGraph: ...@@ -625,7 +640,9 @@ class DistGraph:
assert ( assert (
self._client is not None self._client is not None
), "Distributed module is not initialized. Please call dgl.distributed.initialize." ), "Distributed module is not initialized. Please call dgl.distributed.initialize."
self._g = _get_graph_from_shared_mem(self.graph_name) self._g = _get_graph_from_shared_mem(
self.graph_name, self._use_graphbolt
)
self._gpb = get_shared_mem_partition_book(self.graph_name) self._gpb = get_shared_mem_partition_book(self.graph_name)
if self._gpb is None: if self._gpb is None:
self._gpb = gpb self._gpb = gpb
...@@ -682,10 +699,10 @@ class DistGraph: ...@@ -682,10 +699,10 @@ class DistGraph:
self._edata_store[etype] = data self._edata_store[etype] = data
def __getstate__(self): def __getstate__(self):
return self.graph_name, self._gpb return self.graph_name, self._gpb, self._use_graphbolt
def __setstate__(self, state): def __setstate__(self, state):
self.graph_name, gpb = state self.graph_name, gpb, self._use_graphbolt = state
self._init(gpb) self._init(gpb)
self._init_ndata_store() self._init_ndata_store()
...@@ -1230,6 +1247,9 @@ class DistGraph: ...@@ -1230,6 +1247,9 @@ class DistGraph:
tensor tensor
The destination node ID array. The destination node ID array.
""" """
assert (
self._use_graphbolt is False
), "find_edges is not supported in GraphBolt."
if etype is None: if etype is None:
assert ( assert (
len(self.etypes) == 1 len(self.etypes) == 1
......
...@@ -13,11 +13,13 @@ from multiprocessing import Condition, Manager, Process, Value ...@@ -13,11 +13,13 @@ from multiprocessing import Condition, Manager, Process, Value
import backend as F import backend as F
import dgl import dgl
import dgl.graphbolt as gb
import numpy as np import numpy as np
import pytest import pytest
import torch as th import torch as th
from dgl.data.utils import load_graphs, save_graphs from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import ( from dgl.distributed import (
dgl_partition_to_graphbolt,
DistEmbedding, DistEmbedding,
DistGraph, DistGraph,
DistGraphServer, DistGraphServer,
...@@ -38,12 +40,33 @@ if os.name != "nt": ...@@ -38,12 +40,33 @@ if os.name != "nt":
import struct import struct
def _verify_dist_graph_server_dgl(g):
# verify dtype of underlying graph
cg = g.client_g
for k, dtype in dgl.distributed.dist_graph.RESERVED_FIELD_DTYPE.items():
if k in cg.ndata:
assert (
F.dtype(cg.ndata[k]) == dtype
), "Data type of {} in ndata should be {}.".format(k, dtype)
if k in cg.edata:
assert (
F.dtype(cg.edata[k]) == dtype
), "Data type of {} in edata should be {}.".format(k, dtype)
def _verify_dist_graph_server_graphbolt(g):
graph = g.client_g
assert isinstance(graph, gb.FusedCSCSamplingGraph)
# [Rui][TODO] verify dtype of underlying graph.
def run_server( def run_server(
graph_name, graph_name,
server_id, server_id,
server_count, server_count,
num_clients, num_clients,
shared_mem, shared_mem,
use_graphbolt=False,
): ):
g = DistGraphServer( g = DistGraphServer(
server_id, server_id,
...@@ -53,19 +76,15 @@ def run_server( ...@@ -53,19 +76,15 @@ def run_server(
"/tmp/dist_graph/{}.json".format(graph_name), "/tmp/dist_graph/{}.json".format(graph_name),
disable_shared_mem=not shared_mem, disable_shared_mem=not shared_mem,
graph_format=["csc", "coo"], graph_format=["csc", "coo"],
use_graphbolt=use_graphbolt,
) )
print("start server", server_id) print(f"Starting server[{server_id}] with use_graphbolt={use_graphbolt}")
# verify dtype of underlying graph _verify = (
cg = g.client_g _verify_dist_graph_server_graphbolt
for k, dtype in dgl.distributed.dist_graph.RESERVED_FIELD_DTYPE.items(): if use_graphbolt
if k in cg.ndata: else _verify_dist_graph_server_dgl
assert ( )
F.dtype(cg.ndata[k]) == dtype _verify(g)
), "Data type of {} in ndata should be {}.".format(k, dtype)
if k in cg.edata:
assert (
F.dtype(cg.edata[k]) == dtype
), "Data type of {} in edata should be {}.".format(k, dtype)
g.start() g.start()
...@@ -110,18 +129,26 @@ def check_dist_graph_empty(g, num_clients, num_nodes, num_edges): ...@@ -110,18 +129,26 @@ def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
def run_client_empty( def run_client_empty(
graph_name, part_id, server_count, num_clients, num_nodes, num_edges graph_name,
part_id,
server_count,
num_clients,
num_nodes,
num_edges,
use_graphbolt=False,
): ):
os.environ["DGL_NUM_SERVER"] = str(server_count) os.environ["DGL_NUM_SERVER"] = str(server_count)
dgl.distributed.initialize("kv_ip_config.txt") dgl.distributed.initialize("kv_ip_config.txt")
gpb, graph_name, _, _ = load_partition_book( gpb, graph_name, _, _ = load_partition_book(
"/tmp/dist_graph/{}.json".format(graph_name), part_id "/tmp/dist_graph/{}.json".format(graph_name), part_id
) )
g = DistGraph(graph_name, gpb=gpb) g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
check_dist_graph_empty(g, num_clients, num_nodes, num_edges) check_dist_graph_empty(g, num_clients, num_nodes, num_edges)
def check_server_client_empty(shared_mem, num_servers, num_clients): def check_server_client_empty(
shared_mem, num_servers, num_clients, use_graphbolt=False
):
prepare_dist(num_servers) prepare_dist(num_servers)
g = create_random_graph(10000) g = create_random_graph(10000)
...@@ -129,6 +156,9 @@ def check_server_client_empty(shared_mem, num_servers, num_clients): ...@@ -129,6 +156,9 @@ def check_server_client_empty(shared_mem, num_servers, num_clients):
num_parts = 1 num_parts = 1
graph_name = "dist_graph_test_1" graph_name = "dist_graph_test_1"
partition_graph(g, graph_name, num_parts, "/tmp/dist_graph") partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")
if use_graphbolt:
part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
dgl_partition_to_graphbolt(part_config)
# let's just test on one partition for now. # let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine. # We cannot run multiple servers and clients on the same machine.
...@@ -137,7 +167,14 @@ def check_server_client_empty(shared_mem, num_servers, num_clients): ...@@ -137,7 +167,14 @@ def check_server_client_empty(shared_mem, num_servers, num_clients):
for serv_id in range(num_servers): for serv_id in range(num_servers):
p = ctx.Process( p = ctx.Process(
target=run_server, target=run_server,
args=(graph_name, serv_id, num_servers, num_clients, shared_mem), args=(
graph_name,
serv_id,
num_servers,
num_clients,
shared_mem,
use_graphbolt,
),
) )
serv_ps.append(p) serv_ps.append(p)
p.start() p.start()
...@@ -154,6 +191,7 @@ def check_server_client_empty(shared_mem, num_servers, num_clients): ...@@ -154,6 +191,7 @@ def check_server_client_empty(shared_mem, num_servers, num_clients):
num_clients, num_clients,
g.num_nodes(), g.num_nodes(),
g.num_edges(), g.num_edges(),
use_graphbolt,
), ),
) )
p.start() p.start()
...@@ -178,6 +216,7 @@ def run_client( ...@@ -178,6 +216,7 @@ def run_client(
num_nodes, num_nodes,
num_edges, num_edges,
group_id, group_id,
use_graphbolt=False,
): ):
os.environ["DGL_NUM_SERVER"] = str(server_count) os.environ["DGL_NUM_SERVER"] = str(server_count)
os.environ["DGL_GROUP_ID"] = str(group_id) os.environ["DGL_GROUP_ID"] = str(group_id)
...@@ -185,8 +224,10 @@ def run_client( ...@@ -185,8 +224,10 @@ def run_client(
gpb, graph_name, _, _ = load_partition_book( gpb, graph_name, _, _ = load_partition_book(
"/tmp/dist_graph/{}.json".format(graph_name), part_id "/tmp/dist_graph/{}.json".format(graph_name), part_id
) )
g = DistGraph(graph_name, gpb=gpb) g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
check_dist_graph(g, num_clients, num_nodes, num_edges) check_dist_graph(
g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt
)
def run_emb_client( def run_emb_client(
...@@ -270,14 +311,20 @@ def check_dist_optim_store(rank, num_nodes, optimizer_states, save): ...@@ -270,14 +311,20 @@ def check_dist_optim_store(rank, num_nodes, optimizer_states, save):
def run_client_hierarchy( def run_client_hierarchy(
graph_name, part_id, server_count, node_mask, edge_mask, return_dict graph_name,
part_id,
server_count,
node_mask,
edge_mask,
return_dict,
use_graphbolt=False,
): ):
os.environ["DGL_NUM_SERVER"] = str(server_count) os.environ["DGL_NUM_SERVER"] = str(server_count)
dgl.distributed.initialize("kv_ip_config.txt") dgl.distributed.initialize("kv_ip_config.txt")
gpb, graph_name, _, _ = load_partition_book( gpb, graph_name, _, _ = load_partition_book(
"/tmp/dist_graph/{}.json".format(graph_name), part_id "/tmp/dist_graph/{}.json".format(graph_name), part_id
) )
g = DistGraph(graph_name, gpb=gpb) g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
node_mask = F.tensor(node_mask) node_mask = F.tensor(node_mask)
edge_mask = F.tensor(edge_mask) edge_mask = F.tensor(edge_mask)
nodes = node_split( nodes = node_split(
...@@ -355,7 +402,7 @@ def check_dist_emb(g, num_clients, num_nodes, num_edges): ...@@ -355,7 +402,7 @@ def check_dist_emb(g, num_clients, num_nodes, num_edges):
sys.exit(-1) sys.exit(-1)
def check_dist_graph(g, num_clients, num_nodes, num_edges): def check_dist_graph(g, num_clients, num_nodes, num_edges, use_graphbolt=False):
# Test API # Test API
assert g.num_nodes() == num_nodes assert g.num_nodes() == num_nodes
assert g.num_edges() == num_edges assert g.num_edges() == num_edges
...@@ -373,6 +420,12 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges): ...@@ -373,6 +420,12 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges):
assert np.all(F.asnumpy(feats == eids)) assert np.all(F.asnumpy(feats == eids))
# Test edge_subgraph # Test edge_subgraph
if use_graphbolt:
with pytest.raises(
AssertionError, match="find_edges is not supported in GraphBolt."
):
g.edge_subgraph(eids)
else:
sg = g.edge_subgraph(eids) sg = g.edge_subgraph(eids)
assert sg.num_edges() == len(eids) assert sg.num_edges() == len(eids)
assert F.array_equal(sg.edata[dgl.EID], eids) assert F.array_equal(sg.edata[dgl.EID], eids)
...@@ -522,7 +575,9 @@ def check_dist_emb_server_client( ...@@ -522,7 +575,9 @@ def check_dist_emb_server_client(
print("clients have terminated") print("clients have terminated")
def check_server_client(shared_mem, num_servers, num_clients, num_groups=1): def check_server_client(
shared_mem, num_servers, num_clients, num_groups=1, use_graphbolt=False
):
prepare_dist(num_servers) prepare_dist(num_servers)
g = create_random_graph(10000) g = create_random_graph(10000)
...@@ -532,6 +587,9 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1): ...@@ -532,6 +587,9 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1) g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1) g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
partition_graph(g, graph_name, num_parts, "/tmp/dist_graph") partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")
if use_graphbolt:
part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
dgl_partition_to_graphbolt(part_config)
# let's just test on one partition for now. # let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine. # We cannot run multiple servers and clients on the same machine.
...@@ -546,6 +604,7 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1): ...@@ -546,6 +604,7 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
num_servers, num_servers,
num_clients, num_clients,
shared_mem, shared_mem,
use_graphbolt,
), ),
) )
serv_ps.append(p) serv_ps.append(p)
...@@ -566,6 +625,7 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1): ...@@ -566,6 +625,7 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
g.num_nodes(), g.num_nodes(),
g.num_edges(), g.num_edges(),
group_id, group_id,
use_graphbolt,
), ),
) )
p.start() p.start()
...@@ -582,7 +642,12 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1): ...@@ -582,7 +642,12 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
print("clients have terminated") print("clients have terminated")
def check_server_client_hierarchy(shared_mem, num_servers, num_clients): def check_server_client_hierarchy(
shared_mem, num_servers, num_clients, use_graphbolt=False
):
if num_clients == 1:
# skip this test if there is only one client.
return
prepare_dist(num_servers) prepare_dist(num_servers)
g = create_random_graph(10000) g = create_random_graph(10000)
...@@ -598,6 +663,9 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients): ...@@ -598,6 +663,9 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
"/tmp/dist_graph", "/tmp/dist_graph",
num_trainers_per_machine=num_clients, num_trainers_per_machine=num_clients,
) )
if use_graphbolt:
part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
dgl_partition_to_graphbolt(part_config)
# let's just test on one partition for now. # let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine. # We cannot run multiple servers and clients on the same machine.
...@@ -606,7 +674,14 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients): ...@@ -606,7 +674,14 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
for serv_id in range(num_servers): for serv_id in range(num_servers):
p = ctx.Process( p = ctx.Process(
target=run_server, target=run_server,
args=(graph_name, serv_id, num_servers, num_clients, shared_mem), args=(
graph_name,
serv_id,
num_servers,
num_clients,
shared_mem,
use_graphbolt,
),
) )
serv_ps.append(p) serv_ps.append(p)
p.start() p.start()
...@@ -633,6 +708,7 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients): ...@@ -633,6 +708,7 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
node_mask, node_mask,
edge_mask, edge_mask,
return_dict, return_dict,
use_graphbolt,
), ),
) )
p.start() p.start()
...@@ -658,15 +734,23 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients): ...@@ -658,15 +734,23 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
def run_client_hetero( def run_client_hetero(
graph_name, part_id, server_count, num_clients, num_nodes, num_edges graph_name,
part_id,
server_count,
num_clients,
num_nodes,
num_edges,
use_graphbolt=False,
): ):
os.environ["DGL_NUM_SERVER"] = str(server_count) os.environ["DGL_NUM_SERVER"] = str(server_count)
dgl.distributed.initialize("kv_ip_config.txt") dgl.distributed.initialize("kv_ip_config.txt")
gpb, graph_name, _, _ = load_partition_book( gpb, graph_name, _, _ = load_partition_book(
"/tmp/dist_graph/{}.json".format(graph_name), part_id "/tmp/dist_graph/{}.json".format(graph_name), part_id
) )
g = DistGraph(graph_name, gpb=gpb) g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
check_dist_graph_hetero(g, num_clients, num_nodes, num_edges) check_dist_graph_hetero(
g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt
)
def create_random_hetero(): def create_random_hetero():
...@@ -701,7 +785,9 @@ def create_random_hetero(): ...@@ -701,7 +785,9 @@ def create_random_hetero():
return g return g
def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges): def check_dist_graph_hetero(
g, num_clients, num_nodes, num_edges, use_graphbolt=False
):
# Test API # Test API
for ntype in num_nodes: for ntype in num_nodes:
assert ntype in g.ntypes assert ntype in g.ntypes
...@@ -754,6 +840,12 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges): ...@@ -754,6 +840,12 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
assert expect_except assert expect_except
# Test edge_subgraph # Test edge_subgraph
if use_graphbolt:
with pytest.raises(
AssertionError, match="find_edges is not supported in GraphBolt."
):
g.edge_subgraph({"r1": eids})
else:
sg = g.edge_subgraph({"r1": eids}) sg = g.edge_subgraph({"r1": eids})
assert sg.num_edges() == len(eids) assert sg.num_edges() == len(eids)
assert F.array_equal(sg.edata[dgl.EID], eids) assert F.array_equal(sg.edata[dgl.EID], eids)
...@@ -827,7 +919,9 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges): ...@@ -827,7 +919,9 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
print("end") print("end")
def check_server_client_hetero(shared_mem, num_servers, num_clients): def check_server_client_hetero(
shared_mem, num_servers, num_clients, use_graphbolt=False
):
prepare_dist(num_servers) prepare_dist(num_servers)
g = create_random_hetero() g = create_random_hetero()
...@@ -835,6 +929,9 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients): ...@@ -835,6 +929,9 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
num_parts = 1 num_parts = 1
graph_name = "dist_graph_test_3" graph_name = "dist_graph_test_3"
partition_graph(g, graph_name, num_parts, "/tmp/dist_graph") partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")
if use_graphbolt:
part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
dgl_partition_to_graphbolt(part_config)
# let's just test on one partition for now. # let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine. # We cannot run multiple servers and clients on the same machine.
...@@ -843,7 +940,14 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients): ...@@ -843,7 +940,14 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
for serv_id in range(num_servers): for serv_id in range(num_servers):
p = ctx.Process( p = ctx.Process(
target=run_server, target=run_server,
args=(graph_name, serv_id, num_servers, num_clients, shared_mem), args=(
graph_name,
serv_id,
num_servers,
num_clients,
shared_mem,
use_graphbolt,
),
) )
serv_ps.append(p) serv_ps.append(p)
p.start() p.start()
...@@ -862,6 +966,7 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients): ...@@ -862,6 +966,7 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
num_clients, num_clients,
num_nodes, num_nodes,
num_edges, num_edges,
use_graphbolt,
), ),
) )
p.start() p.start()
...@@ -886,21 +991,23 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients): ...@@ -886,21 +991,23 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
@unittest.skipIf( @unittest.skipIf(
dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support" dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
) )
def test_server_client(): @pytest.mark.parametrize("shared_mem", [True])
@pytest.mark.parametrize("num_servers", [1])
@pytest.mark.parametrize("num_clients", [1, 4])
@pytest.mark.parametrize("use_graphbolt", [True, False])
def test_server_client(shared_mem, num_servers, num_clients, use_graphbolt):
reset_envs() reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
check_server_client_hierarchy(False, 1, 4) # [Rui]
check_server_client_empty(True, 1, 1) # 1. `shared_mem=False` (i.e. `disable_shared_mem=True`) is not supported yet. Skip it.
check_server_client_hetero(True, 1, 1) # 2. `num_servers` > 1 does not work on single machine. Skip it.
check_server_client_hetero(False, 1, 1) for func in [
check_server_client(True, 1, 1) check_server_client,
check_server_client(False, 1, 1) check_server_client_hetero,
# [TODO][Rhett] Tests for multiple groups may fail sometimes and check_server_client_empty,
# root cause is unknown. Let's disable them for now. check_server_client_hierarchy,
# check_server_client(True, 2, 2) ]:
# check_server_client(True, 1, 1, 2) func(shared_mem, num_servers, num_clients, use_graphbolt=use_graphbolt)
# check_server_client(False, 1, 1, 2)
# check_server_client(True, 2, 2, 2)
@unittest.skip(reason="Skip due to glitch in CI") @unittest.skip(reason="Skip due to glitch in CI")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment