Unverified Commit 9501ed6a authored by Rhett Ying, committed by GitHub

[Dist] master port should be fixed for all trainers (#4108)

* [Dist] master port should be fixed for all trainers

* add tests for tools/launch.py
parent 92e77330
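
Why the master port must be chosen once: before this change, submit_jobs() called get_available_port() separately for every trainer node while assembling the torch.distributed.launch command, so different nodes could be handed different --master_port values and disagree on where to rendezvous with rank 0. The patch probes the port a single time and reuses it for every node, as the launch.py hunks further below show. A minimal, self-contained sketch of the difference (pick_free_port is a hypothetical stand-in for launch.py's get_available_port, not the real helper):

import socket

def pick_free_port():
    # Hypothetical stand-in: ask the OS for any currently free TCP port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]

# Before the fix: one probe per node, so the ports handed to the trainers can differ.
per_node_ports = [pick_free_port() for _ in range(4)]

# After the fix: probe once, reuse the same value for every node's --master_port.
master_port = pick_free_port()
shared_ports = [master_port] * 4

print(per_node_ports)  # likely four different ports
print(shared_ports)    # identical by construction
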
@@ -42,4 +42,8 @@ export OMP_NUM_THREADS=1
export DMLC_LOG_DEBUG=1
if [ $2 != "gpu" ]; then
    python3 -m pytest -v --capture=tee-sys --junitxml=pytest_distributed.xml tests/distributed/*.py || fail "distributed"
    if [ $DGLBACKEND == "pytorch" ]; then
        python3 -m pip install filelock
        PYTHONPATH=tools:$PYTHONPATH python3 -m pytest -v --capture=tee-sys --junitxml=pytest_tools.xml tests/tools/*.py || fail "tools"
    fi
fi
import unittest
import json
import tempfile
import os

from launch import *

class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
    """wrap_udf_in_torch_dist_launcher()"""
@@ -82,7 +82,8 @@ class TestConstructDglServerEnvVars(unittest.TestCase):
                part_config="path/to/part.config",
                ip_config="path/to/ip.config",
                num_servers=5,
                graph_format="csc",
                keep_alive=False
            ),
            (
                "DGL_ROLE=server "
@@ -93,6 +94,7 @@ class TestConstructDglServerEnvVars(unittest.TestCase):
                "DGL_IP_CONFIG=path/to/ip.config "
                "DGL_NUM_SERVER=5 "
                "DGL_GRAPH_FORMAT=csc "
                "DGL_KEEP_ALIVE=0 "
            )
        )
@@ -110,6 +112,7 @@ class TestConstructDglClientEnvVars(unittest.TestCase):
                num_servers=3,
                graph_format="csc",
                num_omp_threads=4,
                group_id=0,
                pythonpath="some/pythonpath/"
            ),
            (
@@ -122,6 +125,7 @@ class TestConstructDglClientEnvVars(unittest.TestCase):
                "DGL_NUM_SERVER=3 "
                "DGL_GRAPH_FORMAT=csc "
                "OMP_NUM_THREADS=4 "
                "DGL_GROUP_ID=0 "
                "PYTHONPATH=some/pythonpath/ "
            )
        )
@@ -135,6 +139,7 @@ class TestConstructDglClientEnvVars(unittest.TestCase):
                num_servers=3,
                graph_format="csc",
                num_omp_threads=4,
                group_id=0
            ),
            (
                "DGL_DIST_MODE=distributed "
@@ -146,9 +151,65 @@ class TestConstructDglClientEnvVars(unittest.TestCase):
                "DGL_NUM_SERVER=3 "
                "DGL_GRAPH_FORMAT=csc "
                "OMP_NUM_THREADS=4 "
                "DGL_GROUP_ID=0 "
            )
        )
def test_submit_jobs():
    class Args:
        pass
    args = Args()
    with tempfile.TemporaryDirectory() as test_dir:
        # Fake an 8-machine cluster: an ip_config file plus a minimal partition config.
        num_machines = 8
        ip_config = os.path.join(test_dir, 'ip_config.txt')
        with open(ip_config, 'w') as f:
            for i in range(num_machines):
                f.write('{} {}\n'.format('127.0.0.' + str(i), 30050))
        part_config = os.path.join(test_dir, 'ogb-products.json')
        with open(part_config, 'w') as f:
            json.dump({'num_parts': num_machines}, f)
        args.num_trainers = 8
        args.num_samplers = 1
        args.num_servers = 4
        args.workspace = test_dir
        args.part_config = 'ogb-products.json'
        args.ip_config = 'ip_config.txt'
        args.server_name = 'ogb-products'
        args.keep_alive = False
        args.num_server_threads = 1
        args.graph_format = 'csc'
        args.extra_envs = ["NCCL_DEBUG=INFO"]
        args.num_omp_threads = 1
        udf_command = "python3 train_dist.py --num_epochs 10"
        # Dry run: no ssh is performed; the assembled commands are returned instead.
        clients_cmd, servers_cmd = submit_jobs(args, udf_command, dry_run=True)

        def common_checks():
            # Assertions shared by client and server commands; `cmd` is the loop
            # variable of the enclosing for-loops below.
            assert 'cd ' + test_dir in cmd
            assert 'export ' + args.extra_envs[0] in cmd
            assert f'DGL_NUM_SAMPLER={args.num_samplers}' in cmd
            assert f'DGL_NUM_CLIENT={args.num_trainers*(args.num_samplers+1)*num_machines}' in cmd
            assert f'DGL_CONF_PATH={args.part_config}' in cmd
            assert f'DGL_IP_CONFIG={args.ip_config}' in cmd
            assert f'DGL_NUM_SERVER={args.num_servers}' in cmd
            assert f'DGL_GRAPH_FORMAT={args.graph_format}' in cmd
            assert f'OMP_NUM_THREADS={args.num_omp_threads}' in cmd
            assert udf_command[len('python3 '):] in cmd

        for cmd in clients_cmd:
            common_checks()
            assert 'DGL_DIST_MODE=distributed' in cmd
            assert 'DGL_ROLE=client' in cmd
            assert 'DGL_GROUP_ID=0' in cmd
            assert f'python3 -m torch.distributed.launch --nproc_per_node={args.num_trainers} --nnodes={num_machines}' in cmd
            assert '--master_addr=127.0.0' in cmd
            assert '--master_port=1234' in cmd
        for cmd in servers_cmd:
            common_checks()
            assert 'DGL_ROLE=server' in cmd
            assert 'DGL_KEEP_ALIVE=0' in cmd
            assert 'DGL_SERVER_ID=' in cmd
if __name__ == '__main__':
    unittest.main()
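
The test above never opens an SSH connection: it leans entirely on the new dry_run=True path in submit_jobs(), which skips execute_remote() and returns the assembled client and server command strings so they can be inspected. The same hook can be used by hand to preview what the launcher would run; a sketch, assuming PYTHONPATH includes tools/ (as in the CI step above) and an args object populated like the one in test_submit_jobs():

from launch import submit_jobs

# `args` is assumed to be set up exactly as in test_submit_jobs() above.
clients_cmd, servers_cmd = submit_jobs(args, "python3 train_dist.py --num_epochs 10", dry_run=True)
for cmd in servers_cmd + clients_cmd:
    print(cmd)  # full per-node ssh payload; nothing is actually launched
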
@@ -500,8 +500,12 @@ def get_available_port(ip):
            return port
    raise RuntimeError("Failed to get available port for ip~{}".format(ip))

def submit_jobs(args, udf_command, dry_run=False):
    """Submit distributed jobs (server and client processes) via ssh"""
    if dry_run:
        print("Currently it's in dry run mode which means no jobs will be launched.")
    servers_cmd = []
    clients_cmd = []
    hosts = []
    thread_list = []
    server_count_per_machine = 0
@@ -551,7 +555,9 @@ def submit_jobs(args, udf_command):
            cmd = wrap_cmd_with_local_envvars(udf_command, server_env_vars_cur)
            cmd = wrap_cmd_with_extra_envvars(cmd, args.extra_envs) if len(args.extra_envs) > 0 else cmd
            cmd = 'cd ' + str(args.workspace) + '; ' + cmd
            servers_cmd.append(cmd)
            if not dry_run:
                thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))
    else:
        print(f"Use running server {args.server_name}.")
@@ -568,6 +574,8 @@ def submit_jobs(args, udf_command):
        pythonpath=os.environ.get("PYTHONPATH", ""),
    )
    master_addr = hosts[0][0]
    master_port = get_available_port(master_addr)
    for node_id, host in enumerate(hosts):
        ip, _ = host
        # Transform udf_command to follow torch's dist launcher format: `PYTHON_BIN -m torch.distributed.launch ... UDF`
@@ -576,13 +584,19 @@ def submit_jobs(args, udf_command):
            num_trainers=args.num_trainers,
            num_nodes=len(hosts),
            node_rank=node_id,
            master_addr=master_addr,
            master_port=master_port
        )
        cmd = wrap_cmd_with_local_envvars(torch_dist_udf_command, client_env_vars)
        cmd = wrap_cmd_with_extra_envvars(cmd, args.extra_envs) if len(args.extra_envs) > 0 else cmd
        cmd = 'cd ' + str(args.workspace) + '; ' + cmd
        clients_cmd.append(cmd)
        if not dry_run:
            thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))
    # Return the commands of clients/servers directly if in dry run mode.
    if dry_run:
        return clients_cmd, servers_cmd

    # Start a cleanup process dedicated to cleaning up remote training jobs.
    conn1, conn2 = multiprocessing.Pipe()
    ...