Unverified commit 17701174, authored by Da Zheng and committed by GitHub
Browse files

[Test] Fix test of distributed graph (#1748)

* Revert "Serialize server and client start for dist_graph_store test (#1736)"

This reverts commit da16ebf0.

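* sleep: have the client wait with `time.sleep(5)` before connecting, instead of synchronizing with the server through a condition variable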
parent f80fc9bd
...@@ -51,22 +51,15 @@ def create_random_graph(n): ...@@ -51,22 +51,15 @@ def create_random_graph(n):
ig = create_graph_index(arr, readonly=True) ig = create_graph_index(arr, readonly=True)
return dgl.DGLGraph(ig) return dgl.DGLGraph(ig)
def run_server(graph_name, server_id, num_clients, shared_mem, cond_v, shared_v): def run_server(graph_name, server_id, num_clients, shared_mem):
g = DistGraphServer(server_id, "kv_ip_config.txt", num_clients, graph_name, g = DistGraphServer(server_id, "kv_ip_config.txt", num_clients, graph_name,
'/tmp/dist_graph/{}.json'.format(graph_name), '/tmp/dist_graph/{}.json'.format(graph_name),
disable_shared_mem=not shared_mem) disable_shared_mem=not shared_mem)
print('start server', server_id) print('start server', server_id)
cond_v.acquire()
shared_v.value += 1;
cond_v.notify()
cond_v.release()
g.start() g.start()
def run_client(graph_name, part_id, num_nodes, num_edges, num_server, cond_v, shared_v): def run_client(graph_name, part_id, num_nodes, num_edges):
cond_v.acquire() time.sleep(5)
while shared_v.value < num_server:
cond_v.wait()
cond_v.release()
gpb = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name), gpb = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
part_id, None) part_id, None)
g = DistGraph("kv_ip_config.txt", graph_name, gpb=gpb) g = DistGraph("kv_ip_config.txt", graph_name, gpb=gpb)
...@@ -141,11 +134,9 @@ def check_server_client(shared_mem): ...@@ -141,11 +134,9 @@ def check_server_client(shared_mem):
# let's just test on one partition for now. # let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine. # We cannot run multiple servers and clients on the same machine.
serv_ps = [] serv_ps = []
cond_v = Condition()
shared_v = Value('i', 0)
ctx = mp.get_context('spawn') ctx = mp.get_context('spawn')
for serv_id in range(1): for serv_id in range(1):
p = ctx.Process(target=run_server, args=(graph_name, serv_id, 1, shared_mem, cond_v, shared_v)) p = ctx.Process(target=run_server, args=(graph_name, serv_id, 1, shared_mem))
serv_ps.append(p) serv_ps.append(p)
p.start() p.start()
...@@ -153,7 +144,7 @@ def check_server_client(shared_mem): ...@@ -153,7 +144,7 @@ def check_server_client(shared_mem):
for cli_id in range(1): for cli_id in range(1):
print('start client', cli_id) print('start client', cli_id)
p = ctx.Process(target=run_client, args=(graph_name, cli_id, g.number_of_nodes(), p = ctx.Process(target=run_client, args=(graph_name, cli_id, g.number_of_nodes(),
g.number_of_edges(), 1, cond_v, shared_v)) g.number_of_edges()))
p.start() p.start()
cli_ps.append(p) cli_ps.append(p)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment