Unverified Commit a49ec9e5 authored by 9rum, committed by GitHub

[Tools] use torchrun instead of torch.distributed.launch (#6304)

parent 6aba92e9
@@ -21,7 +21,7 @@ class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
             master_port=1234,
         )
         expected = (
-            "python3.7 -m torch.distributed.launch "
+            "python3.7 -m torch.distributed.run "
             "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
             "--master_port=1234 path/to/some/trainer.py arg1 arg2"
         )
@@ -41,7 +41,7 @@ class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
             master_port=1234,
         )
         expected = (
-            "cd path/to && python3.7 -m torch.distributed.launch "
+            "cd path/to && python3.7 -m torch.distributed.run "
             "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
             "--master_port=1234 path/to/some/trainer.py arg1 arg2"
         )
@@ -68,7 +68,7 @@ class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
             master_port=1234,
         )
         expected = (
-            "{python_bin} -m torch.distributed.launch ".format(
+            "{python_bin} -m torch.distributed.run ".format(
                 python_bin=py_bin
             )
             + "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
@@ -221,7 +221,7 @@ def test_submit_jobs():
     assert "DGL_ROLE=client" in cmd
     assert "DGL_GROUP_ID=0" in cmd
     assert (
-        f"python3 -m torch.distributed.launch --nproc_per_node={args.num_trainers} --nnodes={num_machines}"
+        f"python3 -m torch.distributed.run --nproc_per_node={args.num_trainers} --nnodes={num_machines}"
         in cmd
     )
     assert "--master_addr=127.0.0" in cmd
...
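Note on the expected strings above: `torchrun` is the console-script entry point for the `torch.distributed.run` module (PyTorch 1.10+), so only the launcher module name changes in the test expectations; the distributed arguments and the trainer script stay identical. A minimal illustration of this, not part of the PR's test code:

# Illustration only (not from the PR): rebuild the expected command from its
# pieces to show that only the launcher module name differs.
dist_args = (
    "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
    "--master_port=1234 path/to/some/trainer.py arg1 arg2"
)
old_expected = "python3.7 -m torch.distributed.launch " + dist_args
new_expected = "python3.7 -m torch.distributed.run " + dist_args
assert old_expected.replace(".launch ", ".run ") == new_expected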
@@ -17,7 +17,7 @@ from typing import Optional
 def cleanup_proc(get_all_remote_pids, conn):
     """This process tries to clean up the remote training tasks."""
-    print("cleanupu process runs")
+    print("cleanup process runs")
     # This process should not handle SIGINT.
     signal.signal(signal.SIGINT, signal.SIG_IGN)
@@ -228,7 +228,7 @@ def construct_torch_dist_launcher_cmd(
         cmd_str.
     """
     torch_cmd_template = (
-        "-m torch.distributed.launch "
+        "-m torch.distributed.run "
         "--nproc_per_node={nproc_per_node} "
         "--nnodes={nnodes} "
         "--node_rank={node_rank} "
@@ -252,10 +252,10 @@ def wrap_udf_in_torch_dist_launcher(
     master_addr: str,
     master_port: int,
 ) -> str:
-    """Wraps the user-defined function (udf_command) with the torch.distributed.launch module.
+    """Wraps the user-defined function (udf_command) with the torch.distributed.run module.
     Example: if udf_command is "python3 run/some/trainer.py arg1 arg2", then new_df_command becomes:
-        "python3 -m torch.distributed.launch <TORCH DIST ARGS> run/some/trainer.py arg1 arg2
+        "python3 -m torch.distributed.run <TORCH DIST ARGS> run/some/trainer.py arg1 arg2
     udf_command is assumed to consist of pre-commands (optional) followed by the python launcher script (required):
     Examples:
@@ -310,7 +310,7 @@ def wrap_udf_in_torch_dist_launcher(
     # transforms the udf_command from:
     # python path/to/dist_trainer.py arg0 arg1
     # to:
-    # python -m torch.distributed.launch [DIST TORCH ARGS] path/to/dist_trainer.py arg0 arg1
+    # python -m torch.distributed.run [DIST TORCH ARGS] path/to/dist_trainer.py arg0 arg1
     # Note: if there are multiple python commands in `udf_command`, this may do the Wrong Thing, eg launch each
     # python command within the torch distributed launcher.
     new_udf_command = udf_command.replace(
@@ -593,7 +593,7 @@ def submit_jobs(args, udf_command, dry_run=False):
     master_port = get_available_port(master_addr)
     for node_id, host in enumerate(hosts):
         ip, _ = host
-        # Transform udf_command to follow torch's dist launcher format: `PYTHON_BIN -m torch.distributed.launch ... UDF`
+        # Transform udf_command to follow torch's dist launcher format: `PYTHON_BIN -m torch.distributed.run ... UDF`
         torch_dist_udf_command = wrap_udf_in_torch_dist_launcher(
             udf_command=udf_command,
             num_trainers=args.num_trainers,
...
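For readers unfamiliar with the launcher flow above, here is a minimal sketch of the wrapping that `construct_torch_dist_launcher_cmd` and `wrap_udf_in_torch_dist_launcher` perform after this change. It is illustrative only: the helper names `build_torch_dist_cmd` and `wrap_udf` are made up for the sketch, and the real DGL code also handles pre-commands and multiple python binaries, which this version ignores.

# Sketch only: mirrors the torch_cmd_template and the replace-based wrapping shown
# in the diff, under the simplifying assumption that udf_command starts with the
# python binary. Helper names here are hypothetical, not DGL's.
def build_torch_dist_cmd(nproc_per_node, nnodes, node_rank, master_addr, master_port):
    return (
        "-m torch.distributed.run "
        f"--nproc_per_node={nproc_per_node} "
        f"--nnodes={nnodes} "
        f"--node_rank={node_rank} "
        f"--master_addr={master_addr} "
        f"--master_port={master_port}"
    )

def wrap_udf(udf_command, **dist_kwargs):
    # "python3 path/to/trainer.py arg1 arg2"
    #   -> "python3 -m torch.distributed.run <DIST ARGS> path/to/trainer.py arg1 arg2"
    python_bin, rest = udf_command.split(" ", 1)
    return f"{python_bin} {build_torch_dist_cmd(**dist_kwargs)} {rest}"

print(
    wrap_udf(
        "python3 path/to/some/trainer.py arg1 arg2",
        nproc_per_node=2,
        nnodes=2,
        node_rank=1,
        master_addr="127.0.0.1",
        master_port=1234,
    )
)
# -> python3 -m torch.distributed.run --nproc_per_node=2 --nnodes=2 --node_rank=1
#    --master_addr=127.0.0.1 --master_port=1234 path/to/some/trainer.py arg1 arg2

Because `torchrun` invokes the same `torch.distributed.run` module, a per-node command produced this way could equally be spelled `torchrun --nproc_per_node=... path/to/some/trainer.py ...`.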