OpenDAS / dgl / Commits

Commit a49ec9e5 (unverified)
Authored Sep 11, 2023 by 9rum; committed by GitHub on Sep 11, 2023
[Tools] use torchrun instead of torch.distributed.launch (#6304)
Parent: 6aba92e9

Showing 2 changed files with 10 additions and 10 deletions (+10 −10):
tests/tools/test_launch.py  +4 −4
tools/launch.py             +6 −6
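For context: recent PyTorch releases deprecate `python -m torch.distributed.launch` in favor of `python -m torch.distributed.run`, which is what the `torchrun` console script invokes. The launcher flags DGL passes are unchanged, so the rename only swaps the module. A sketch of the before/after invocation, built from the values used in the tests below (illustrative, not output of the tool):

    # before (deprecated launcher)
    python3 -m torch.distributed.launch --nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 --master_port=1234 path/to/some/trainer.py arg1 arg2
    # after (the module form behind the torchrun console script)
    python3 -m torch.distributed.run --nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 --master_port=1234 path/to/some/trainer.py arg1 arg2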
tests/tools/test_launch.py (view file @ a49ec9e5)
@@ -21,7 +21,7 @@ class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
             master_port=1234,
         )
         expected = (
-            "python3.7 -m torch.distributed.launch "
+            "python3.7 -m torch.distributed.run "
             "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
             "--master_port=1234 path/to/some/trainer.py arg1 arg2"
         )
@@ -41,7 +41,7 @@ class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
             master_port=1234,
         )
         expected = (
-            "cd path/to && python3.7 -m torch.distributed.launch "
+            "cd path/to && python3.7 -m torch.distributed.run "
             "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
             "--master_port=1234 path/to/some/trainer.py arg1 arg2"
         )
@@ -68,7 +68,7 @@ class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
             master_port=1234,
         )
         expected = (
-            "{python_bin} -m torch.distributed.launch ".format(
+            "{python_bin} -m torch.distributed.run ".format(
                 python_bin=py_bin
             )
             + "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
@@ -221,7 +221,7 @@ def test_submit_jobs():
         assert "DGL_ROLE=client" in cmd
         assert "DGL_GROUP_ID=0" in cmd
         assert (
-            f"python3 -m torch.distributed.launch --nproc_per_node={args.num_trainers} --nnodes={num_machines} "
+            f"python3 -m torch.distributed.run --nproc_per_node={args.num_trainers} --nnodes={num_machines} "
             in cmd
         )
         assert "--master_addr=127.0.0" in cmd
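To make the expected strings above easier to read, here is a rough sketch of how the wrapper under test is called. Only udf_command, num_trainers, master_addr, and master_port are visible in this diff; the keyword names for the node count and rank below are assumptions inferred from the expected output.

    # Hypothetical call; num_nodes/node_rank keyword names are assumed, not confirmed by this diff.
    cmd = wrap_udf_in_torch_dist_launcher(
        udf_command="python3.7 path/to/some/trainer.py arg1 arg2",
        num_trainers=2,           # -> --nproc_per_node=2
        num_nodes=2,              # -> --nnodes=2 (name assumed)
        node_rank=1,              # -> --node_rank=1 (name assumed)
        master_addr="127.0.0.1",  # -> --master_addr=127.0.0.1
        master_port=1234,         # -> --master_port=1234
    )
    # Per the updated test, cmd is expected to equal:
    # "python3.7 -m torch.distributed.run --nproc_per_node=2 --nnodes=2 --node_rank=1 "
    # "--master_addr=127.0.0.1 --master_port=1234 path/to/some/trainer.py arg1 arg2"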
tools/launch.py (view file @ a49ec9e5)
@@ -17,7 +17,7 @@ from typing import Optional
 def cleanup_proc(get_all_remote_pids, conn):
     """This process tries to clean up the remote training tasks."""
-    print("cleanupu process runs")
+    print("cleanup process runs")
     # This process should not handle SIGINT.
     signal.signal(signal.SIGINT, signal.SIG_IGN)
@@ -228,7 +228,7 @@ def construct_torch_dist_launcher_cmd(
         cmd_str.
     """
     torch_cmd_template = (
-        "-m torch.distributed.launch "
+        "-m torch.distributed.run "
         "--nproc_per_node={nproc_per_node} "
         "--nnodes={nnodes} "
         "--node_rank={node_rank} "
@@ -252,10 +252,10 @@ def wrap_udf_in_torch_dist_launcher(
     master_addr: str,
     master_port: int,
 ) -> str:
-    """Wraps the user-defined function (udf_command) with the torch.distributed.launch module.
+    """Wraps the user-defined function (udf_command) with the torch.distributed.run module.
     Example: if udf_command is "python3 run/some/trainer.py arg1 arg2", then new_df_command becomes:
-        "python3 -m torch.distributed.launch <TORCH DIST ARGS> run/some/trainer.py arg1 arg2
+        "python3 -m torch.distributed.run <TORCH DIST ARGS> run/some/trainer.py arg1 arg2
     udf_command is assumed to consist of pre-commands (optional) followed by the python launcher script (required):
     Examples:
@@ -310,7 +310,7 @@ def wrap_udf_in_torch_dist_launcher(
     # transforms the udf_command from:
     #     python path/to/dist_trainer.py arg0 arg1
     # to:
-    #     python -m torch.distributed.launch [DIST TORCH ARGS] path/to/dist_trainer.py arg0 arg1
+    #     python -m torch.distributed.run [DIST TORCH ARGS] path/to/dist_trainer.py arg0 arg1
     # Note: if there are multiple python commands in `udf_command`, this may do the Wrong Thing, eg launch each
     # python command within the torch distributed launcher.
     new_udf_command = udf_command.replace(
@@ -593,7 +593,7 @@ def submit_jobs(args, udf_command, dry_run=False):
     master_port = get_available_port(master_addr)
     for node_id, host in enumerate(hosts):
         ip, _ = host
-        # Transform udf_command to follow torch's dist launcher format: `PYTHON_BIN -m torch.distributed.launch ... UDF`
+        # Transform udf_command to follow torch's dist launcher format: `PYTHON_BIN -m torch.distributed.run ... UDF`
         torch_dist_udf_command = wrap_udf_in_torch_dist_launcher(
             udf_command=udf_command,
             num_trainers=args.num_trainers,
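For reference, the template updated in construct_torch_dist_launcher_cmd is filled roughly as below. Only the fields up to node_rank are visible in the hunk above; the master_addr and master_port fields are assumed here to round out the sketch.

    # Illustrative only: visible part of the template, formatted with the values used in the tests.
    torch_cmd_template = (
        "-m torch.distributed.run "
        "--nproc_per_node={nproc_per_node} "
        "--nnodes={nnodes} "
        "--node_rank={node_rank} "
        "--master_addr={master_addr} "  # field assumed, not shown in this hunk
        "--master_port={master_port}"   # field assumed, not shown in this hunk
    )
    print(torch_cmd_template.format(
        nproc_per_node=2, nnodes=2, node_rank=1,
        master_addr="127.0.0.1", master_port=1234,
    ))
    # -m torch.distributed.run --nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 --master_port=1234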