Unverified Commit c40bbf4f authored by Eric Kim, committed by GitHub

[Tools] Refactor tools/launch.py to handle more python binary names (#3205)

* Refactors torch dist launcher udf-wrap code to handle more python versions

* minor changes
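
The wrapper now auto-detects the python binary in the user command from an allowlist (python3.6 through python3.9, python3, python2.7, python2, falling back to python) instead of only special-casing python3/python2. A sketch of the transformation, using the values from the new unit tests (2 trainers, 2 nodes, node rank 1, master 127.0.0.1:1234):

    python3.8 path/to/some/trainer.py arg1 arg2

becomes

    python3.8 -m torch.distributed.launch --nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 --master_port=1234 path/to/some/trainer.py arg1 arg2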
parent 213e27f0
import unittest

from tools.launch import wrap_udf_in_torch_dist_launcher


class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
    """wrap_udf_in_torch_dist_launcher()"""

    def test_simple(self):
        # test that a simple udf_command is correctly wrapped
        udf_command = "python3.7 path/to/some/trainer.py arg1 arg2"
        wrapped_udf_command = wrap_udf_in_torch_dist_launcher(
            udf_command=udf_command,
            num_trainers=2,
            num_nodes=2,
            node_rank=1,
            master_addr="127.0.0.1",
            master_port=1234,
        )
        expected = "python3.7 -m torch.distributed.launch " \
                   "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 " \
                   "--master_port=1234 path/to/some/trainer.py arg1 arg2"
        self.assertEqual(wrapped_udf_command, expected)

    def test_chained_udf(self):
        # test that a chained udf_command is properly handled
        udf_command = "cd path/to && python3.7 path/to/some/trainer.py arg1 arg2"
        wrapped_udf_command = wrap_udf_in_torch_dist_launcher(
            udf_command=udf_command,
            num_trainers=2,
            num_nodes=2,
            node_rank=1,
            master_addr="127.0.0.1",
            master_port=1234,
        )
        expected = "cd path/to && python3.7 -m torch.distributed.launch " \
                   "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 " \
                   "--master_port=1234 path/to/some/trainer.py arg1 arg2"
        self.assertEqual(wrapped_udf_command, expected)

    def test_py_versions(self):
        # test that this correctly handles different py versions/binaries
        py_binaries = (
            "python3.7", "python3.8", "python3.9", "python3", "python"
        )
        udf_command = "{python_bin} path/to/some/trainer.py arg1 arg2"
        for py_bin in py_binaries:
            wrapped_udf_command = wrap_udf_in_torch_dist_launcher(
                udf_command=udf_command.format(python_bin=py_bin),
                num_trainers=2,
                num_nodes=2,
                node_rank=1,
                master_addr="127.0.0.1",
                master_port=1234,
            )
            expected = "{python_bin} -m torch.distributed.launch ".format(python_bin=py_bin) + \
                       "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 " \
                       "--master_port=1234 path/to/some/trainer.py arg1 arg2"
            self.assertEqual(wrapped_udf_command, expected)


if __name__ == '__main__':
    unittest.main()
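
The tests above can be run with the standard unittest runner from the repository root, so that tools.launch is importable. The test module's path is not shown in this commit view, so the placeholder below is an assumption:

    python3 -m unittest <path.to.this.test_module> -v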
@@ -158,6 +158,111 @@ def get_all_remote_pids(hosts, ssh_port, udf_command):
            remote_pids[(ip, ssh_port)] = pids
    return remote_pids
def construct_torch_dist_launcher_cmd(
    num_trainers: int,
    num_nodes: int,
    node_rank: int,
    master_addr: str,
    master_port: int
) -> str:
    """Constructs the torch distributed launcher command.

    Helper function.

    Args:
        num_trainers: Number of trainer processes to launch on each node.
        num_nodes: Total number of nodes participating in training.
        node_rank: Rank of this node among all participating nodes.
        master_addr: IP address of the master (rank 0) node.
        master_port: Port on the master node used for rendezvous.

    Returns:
        cmd_str: The torch.distributed.launch invocation arguments as a single string.
    """
    torch_cmd_template = "-m torch.distributed.launch " \
                         "--nproc_per_node={nproc_per_node} " \
                         "--nnodes={nnodes} " \
                         "--node_rank={node_rank} " \
                         "--master_addr={master_addr} " \
                         "--master_port={master_port}"
    return torch_cmd_template.format(
        nproc_per_node=num_trainers,
        nnodes=num_nodes,
        node_rank=node_rank,
        master_addr=master_addr,
        master_port=master_port
    )
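# Illustration (not part of the committed file): with the values used in the unit
# tests above, the helper returns the following argument string, which is later
# prefixed with the detected python binary:
#
#   construct_torch_dist_launcher_cmd(
#       num_trainers=2, num_nodes=2, node_rank=1,
#       master_addr="127.0.0.1", master_port=1234)
#   -> "-m torch.distributed.launch --nproc_per_node=2 --nnodes=2 "
#      "--node_rank=1 --master_addr=127.0.0.1 --master_port=1234"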
def wrap_udf_in_torch_dist_launcher(
    udf_command: str,
    num_trainers: int,
    num_nodes: int,
    node_rank: int,
    master_addr: str,
    master_port: int,
) -> str:
    """Wraps the user-defined function (udf_command) with the torch.distributed.launch module.

    Example: if udf_command is "python3 run/some/trainer.py arg1 arg2", then the returned
    new_udf_command becomes:
        "python3 -m torch.distributed.launch <TORCH DIST ARGS> run/some/trainer.py arg1 arg2"

    udf_command is assumed to consist of pre-commands (optional) followed by the python
    launcher script (required):

    Examples:
        # simple
        python3.7 path/to/some/trainer.py arg1 arg2

        # multi-command
        (cd some/dir && python3.7 path/to/some/trainer.py arg1 arg2)

    IMPORTANT: If udf_command contains multiple python commands, the behavior is undefined.

    Args:
        udf_command: The user-defined command to be wrapped.
        num_trainers: Number of trainer processes to launch on each node.
        num_nodes: Total number of nodes participating in training.
        node_rank: Rank of this node among all participating nodes.
        master_addr: IP address of the master (rank 0) node.
        master_port: Port on the master node used for rendezvous.

    Returns:
        The udf_command with its python invocation wrapped by torch.distributed.launch.
    """
    torch_dist_cmd = construct_torch_dist_launcher_cmd(
        num_trainers=num_trainers,
        num_nodes=num_nodes,
        node_rank=node_rank,
        master_addr=master_addr,
        master_port=master_port
    )
    # Auto-detect the python binary that kicks off the distributed trainer code.
    # Note: the allowlist order matters; the FIRST matching entry wins. Please keep
    # entries ordered from most-specific to least-specific, e.g.:
    #   (python3.7, python3.8) -> (python3)
    # The allowed python versions come from: https://www.dgl.ai/pages/start.html
    python_bin_allowlist = (
        "python3.6", "python3.7", "python3.8", "python3.9", "python3",
        # For backwards compatibility, accept python2; however, DGL is technically
        # a py3 library, so this is not recommended.
        "python2.7", "python2",
    )
    # If none of the candidate python bins match, fall back to the default `python`.
    python_bin = "python"
    for candidate_python_bin in python_bin_allowlist:
        if candidate_python_bin in udf_command:
            python_bin = candidate_python_bin
            break

    # Transforms the udf_command from:
    #     python path/to/dist_trainer.py arg0 arg1
    # to:
    #     python -m torch.distributed.launch [DIST TORCH ARGS] path/to/dist_trainer.py arg0 arg1
    # Note: if there are multiple python commands in `udf_command`, this may do the
    # Wrong Thing, e.g. wrap each python command in the torch distributed launcher.
    new_udf_command = udf_command.replace(python_bin, f"{python_bin} {torch_dist_cmd}")
    return new_udf_command
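# Illustration (not part of the committed file): wrapping a chained command, as
# asserted by test_chained_udf above, leaves the pre-command untouched and only
# rewrites the python invocation:
#
#   wrap_udf_in_torch_dist_launcher(
#       udf_command="cd path/to && python3.7 path/to/some/trainer.py arg1 arg2",
#       num_trainers=2, num_nodes=2, node_rank=1,
#       master_addr="127.0.0.1", master_port=1234)
#   -> "cd path/to && python3.7 -m torch.distributed.launch --nproc_per_node=2 "
#      "--nnodes=2 --node_rank=1 --master_addr=127.0.0.1 --master_port=1234 "
#      "path/to/some/trainer.py arg1 arg2"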
def submit_jobs(args, udf_command):
    """Submit distributed jobs (server and client processes) via ssh"""
    hosts = []
@@ -219,22 +324,18 @@ def submit_jobs(args, udf_command):
    client_cmd = client_cmd + ' ' + 'PYTHONPATH=' + os.environ.get('PYTHONPATH')
    client_cmd = client_cmd + ' ' + 'DGL_GRAPH_FORMAT=' + str(args.graph_format)
-   torch_cmd = '-m torch.distributed.launch'
-   torch_cmd = torch_cmd + ' ' + '--nproc_per_node=' + str(args.num_trainers)
-   torch_cmd = torch_cmd + ' ' + '--nnodes=' + str(len(hosts))
-   torch_cmd = torch_cmd + ' ' + '--node_rank=' + str(0)
-   torch_cmd = torch_cmd + ' ' + '--master_addr=' + str(hosts[0][0])
-   torch_cmd = torch_cmd + ' ' + '--master_port=' + str(1234)
    for node_id, host in enumerate(hosts):
        ip, _ = host
-       new_torch_cmd = torch_cmd.replace('node_rank=0', 'node_rank='+str(node_id))
-       if 'python3' in udf_command:
-           new_udf_command = udf_command.replace('python3', 'python3 ' + new_torch_cmd)
-       elif 'python2' in udf_command:
-           new_udf_command = udf_command.replace('python2', 'python2 ' + new_torch_cmd)
-       else:
-           new_udf_command = udf_command.replace('python', 'python ' + new_torch_cmd)
-       cmd = client_cmd + ' ' + new_udf_command
+       # Transform udf_command to follow torch's dist launcher format:
+       # `PYTHON_BIN -m torch.distributed.launch ... UDF`
+       torch_dist_udf_command = wrap_udf_in_torch_dist_launcher(
+           udf_command=udf_command,
+           num_trainers=args.num_trainers,
+           num_nodes=len(hosts),
+           node_rank=node_id,
+           master_addr=hosts[0][0],
+           master_port=1234,
+       )
+       cmd = client_cmd + ' ' + torch_dist_udf_command
        cmd = 'cd ' + str(args.workspace) + '; ' + cmd
        thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))
...