Unverified Commit 9c8c162a authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Fix] sleep for a while when launching clients which will connect to … (#3704)

* [Fix] sleep for a while when launching clients which will connect to multiple servers

* pre-allocate more ports

* no multiple partitions on single machine
parent 701b4fcc
...@@ -80,7 +80,7 @@ def run_client_empty(graph_name, part_id, server_count, num_clients, num_nodes, ...@@ -80,7 +80,7 @@ def run_client_empty(graph_name, part_id, server_count, num_clients, num_nodes,
check_dist_graph_empty(g, num_clients, num_nodes, num_edges) check_dist_graph_empty(g, num_clients, num_nodes, num_edges)
def check_server_client_empty(shared_mem, num_servers, num_clients): def check_server_client_empty(shared_mem, num_servers, num_clients):
prepare_dist() prepare_dist(num_servers)
g = create_random_graph(10000) g = create_random_graph(10000)
# Partition the graph # Partition the graph
...@@ -281,7 +281,7 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges): ...@@ -281,7 +281,7 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges):
print('end') print('end')
def check_dist_emb_server_client(shared_mem, num_servers, num_clients, num_groups=1): def check_dist_emb_server_client(shared_mem, num_servers, num_clients, num_groups=1):
prepare_dist() prepare_dist(num_servers)
g = create_random_graph(10000) g = create_random_graph(10000)
# Partition the graph # Partition the graph
...@@ -311,6 +311,7 @@ def check_dist_emb_server_client(shared_mem, num_servers, num_clients, num_group ...@@ -311,6 +311,7 @@ def check_dist_emb_server_client(shared_mem, num_servers, num_clients, num_group
g.number_of_edges(), g.number_of_edges(),
group_id)) group_id))
p.start() p.start()
time.sleep(1) # avoid race condition when instantiating DistGraph
cli_ps.append(p) cli_ps.append(p)
for p in cli_ps: for p in cli_ps:
...@@ -328,7 +329,7 @@ def check_dist_emb_server_client(shared_mem, num_servers, num_clients, num_group ...@@ -328,7 +329,7 @@ def check_dist_emb_server_client(shared_mem, num_servers, num_clients, num_group
print('clients have terminated') print('clients have terminated')
def check_server_client(shared_mem, num_servers, num_clients, num_groups=1): def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
prepare_dist() prepare_dist(num_servers)
g = create_random_graph(10000) g = create_random_graph(10000)
# Partition the graph # Partition the graph
...@@ -357,6 +358,7 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1): ...@@ -357,6 +358,7 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
p = ctx.Process(target=run_client, args=(graph_name, 0, num_servers, num_clients, g.number_of_nodes(), p = ctx.Process(target=run_client, args=(graph_name, 0, num_servers, num_clients, g.number_of_nodes(),
g.number_of_edges(), group_id)) g.number_of_edges(), group_id))
p.start() p.start()
time.sleep(1) # avoid race condition when instantiating DistGraph
cli_ps.append(p) cli_ps.append(p)
for p in cli_ps: for p in cli_ps:
p.join() p.join()
...@@ -372,7 +374,7 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1): ...@@ -372,7 +374,7 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
print('clients have terminated') print('clients have terminated')
def check_server_client_hierarchy(shared_mem, num_servers, num_clients): def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
prepare_dist() prepare_dist(num_servers)
g = create_random_graph(10000) g = create_random_graph(10000)
# Partition the graph # Partition the graph
...@@ -535,7 +537,7 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges): ...@@ -535,7 +537,7 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
print('end') print('end')
def check_server_client_hetero(shared_mem, num_servers, num_clients): def check_server_client_hetero(shared_mem, num_servers, num_clients):
prepare_dist() prepare_dist(num_servers)
g = create_random_hetero() g = create_random_hetero()
# Partition the graph # Partition the graph
...@@ -641,7 +643,6 @@ def test_standalone_node_emb(): ...@@ -641,7 +643,6 @@ def test_standalone_node_emb():
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_split(): def test_split():
#prepare_dist()
g = create_random_graph(10000) g = create_random_graph(10000)
num_parts = 4 num_parts = 4
num_hops = 2 num_hops = 2
...@@ -696,7 +697,6 @@ def test_split(): ...@@ -696,7 +697,6 @@ def test_split():
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_split_even(): def test_split_even():
#prepare_dist(1)
g = create_random_graph(10000) g = create_random_graph(10000)
num_parts = 4 num_parts = 4
num_hops = 2 num_hops = 2
...@@ -763,8 +763,8 @@ def test_split_even(): ...@@ -763,8 +763,8 @@ def test_split_even():
assert np.all(all_nodes == F.asnumpy(all_nodes2)) assert np.all(all_nodes == F.asnumpy(all_nodes2))
assert np.all(all_edges == F.asnumpy(all_edges2)) assert np.all(all_edges == F.asnumpy(all_edges2))
def prepare_dist(): def prepare_dist(num_servers=1):
generate_ip_config("kv_ip_config.txt", 1, 1) generate_ip_config("kv_ip_config.txt", 1, num_servers=num_servers)
if __name__ == '__main__': if __name__ == '__main__':
os.makedirs('/tmp/dist_graph', exist_ok=True) os.makedirs('/tmp/dist_graph', exist_ok=True)
......
...@@ -301,6 +301,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1): ...@@ -301,6 +301,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
for group_id in range(num_groups): for group_id in range(num_groups):
p = ctx.Process(target=start_sample_client_shuffle, args=(client_id, tmpdir, num_server > 1, g, num_server, group_id)) p = ctx.Process(target=start_sample_client_shuffle, args=(client_id, tmpdir, num_server > 1, g, num_server, group_id))
p.start() p.start()
time.sleep(1) # avoid race condition when instantiating DistGraph
pclient_list.append(p) pclient_list.append(p)
for p in pclient_list: for p in pclient_list:
p.join() p.join()
...@@ -563,7 +564,7 @@ def test_rpc_sampling_shuffle(num_server): ...@@ -563,7 +564,7 @@ def test_rpc_sampling_shuffle(num_server):
os.environ['DGL_DIST_MODE'] = 'distributed' os.environ['DGL_DIST_MODE'] = 'distributed'
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_sampling_shuffle(Path(tmpdirname), num_server) check_rpc_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_sampling_shuffle(Path(tmpdirname), num_server, num_groups=5) check_rpc_sampling_shuffle(Path(tmpdirname), num_server, num_groups=2)
check_rpc_hetero_sampling_shuffle(Path(tmpdirname), num_server) check_rpc_hetero_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), num_server) check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), num_server) check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), num_server)
......
...@@ -216,6 +216,10 @@ def check_neg_dataloader(g, tmpdir, num_server, num_workers): ...@@ -216,6 +216,10 @@ def check_neg_dataloader(g, tmpdir, num_server, num_workers):
@pytest.mark.parametrize("num_groups", [1, 2]) @pytest.mark.parametrize("num_groups", [1, 2])
def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle, num_groups): def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle, num_groups):
reset_envs() reset_envs()
# No multiple partitions on single machine for
# multiple client groups in case of race condition.
if num_groups > 1:
num_server = 1
generate_ip_config("mp_ip_config.txt", num_server, num_server) generate_ip_config("mp_ip_config.txt", num_server, num_server)
g = CitationGraphDataset("cora")[0] g = CitationGraphDataset("cora")[0]
...@@ -246,6 +250,7 @@ def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle, ...@@ -246,6 +250,7 @@ def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle,
p = ctx.Process(target=start_dist_dataloader, args=( p = ctx.Process(target=start_dist_dataloader, args=(
trainer_id, tmpdir, num_server, drop_last, orig_nid, orig_eid, group_id)) trainer_id, tmpdir, num_server, drop_last, orig_nid, orig_eid, group_id))
p.start() p.start()
time.sleep(1) # avoid race condition when instantiating DistGraph
ptrainer_list.append(p) ptrainer_list.append(p)
for p in ptrainer_list: for p in ptrainer_list:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment