Unverified Commit da16ebf0 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

Serialize server and client start for dist_graph_store test (#1736)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent 087992f1
...@@ -51,14 +51,22 @@ def create_random_graph(n): ...@@ -51,14 +51,22 @@ def create_random_graph(n):
ig = create_graph_index(arr, readonly=True) ig = create_graph_index(arr, readonly=True)
return dgl.DGLGraph(ig) return dgl.DGLGraph(ig)
def run_server(graph_name, server_id, num_clients, shared_mem): def run_server(graph_name, server_id, num_clients, shared_mem, cond_v, shared_v):
g = DistGraphServer(server_id, "kv_ip_config.txt", num_clients, graph_name, g = DistGraphServer(server_id, "kv_ip_config.txt", num_clients, graph_name,
'/tmp/dist_graph/{}.json'.format(graph_name), '/tmp/dist_graph/{}.json'.format(graph_name),
disable_shared_mem=not shared_mem) disable_shared_mem=not shared_mem)
print('start server', server_id) print('start server', server_id)
cond_v.acquire()
shared_v.value += 1;
cond_v.notify()
cond_v.release()
g.start() g.start()
def run_client(graph_name, part_id, num_nodes, num_edges): def run_client(graph_name, part_id, num_nodes, num_edges, num_server, cond_v, shared_v):
cond_v.acquire()
while shared_v.value < num_server:
cond_v.wait()
cond_v.release()
gpb = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name), gpb = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
part_id, None) part_id, None)
g = DistGraph("kv_ip_config.txt", graph_name, gpb=gpb) g = DistGraph("kv_ip_config.txt", graph_name, gpb=gpb)
...@@ -133,9 +141,11 @@ def check_server_client(shared_mem): ...@@ -133,9 +141,11 @@ def check_server_client(shared_mem):
# let's just test on one partition for now. # let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine. # We cannot run multiple servers and clients on the same machine.
serv_ps = [] serv_ps = []
cond_v = Condition()
shared_v = Value('i', 0)
ctx = mp.get_context('spawn') ctx = mp.get_context('spawn')
for serv_id in range(1): for serv_id in range(1):
p = ctx.Process(target=run_server, args=(graph_name, serv_id, 1, shared_mem)) p = ctx.Process(target=run_server, args=(graph_name, serv_id, 1, shared_mem, cond_v, shared_v))
serv_ps.append(p) serv_ps.append(p)
p.start() p.start()
...@@ -143,7 +153,7 @@ def check_server_client(shared_mem): ...@@ -143,7 +153,7 @@ def check_server_client(shared_mem):
for cli_id in range(1): for cli_id in range(1):
print('start client', cli_id) print('start client', cli_id)
p = ctx.Process(target=run_client, args=(graph_name, cli_id, g.number_of_nodes(), p = ctx.Process(target=run_client, args=(graph_name, cli_id, g.number_of_nodes(),
g.number_of_edges())) g.number_of_edges(), 1, cond_v, shared_v))
p.start() p.start()
cli_ps.append(p) cli_ps.append(p)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment