Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
f0fbbc16
Unverified
Commit
f0fbbc16
authored
Aug 08, 2020
by
Da Zheng
Committed by
GitHub
Aug 08, 2020
Browse files
[Distributed] Fix the launch script. (#1977)
* update launch script * check the correctness of launch script. * fix.
parent
4b8eaf20
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
8 deletions
+23
-8
tools/launch.py
tools/launch.py
+23
-8
No files found.
tools/launch.py
View file @
f0fbbc16
...
@@ -7,6 +7,7 @@ import argparse
...
@@ -7,6 +7,7 @@ import argparse
import
signal
import
signal
import
logging
import
logging
import
time
import
time
import
json
from
threading
import
Thread
from
threading
import
Thread
def
execute_remote
(
cmd
,
ip
,
thread_list
):
def
execute_remote
(
cmd
,
ip
,
thread_list
):
...
@@ -26,6 +27,8 @@ def submit_jobs(args, udf_command):
...
@@ -26,6 +27,8 @@ def submit_jobs(args, udf_command):
hosts
=
[]
hosts
=
[]
thread_list
=
[]
thread_list
=
[]
server_count_per_machine
=
0
server_count_per_machine
=
0
# Get the IP addresses of the cluster.
ip_config
=
args
.
workspace
+
'/'
+
args
.
ip_config
ip_config
=
args
.
workspace
+
'/'
+
args
.
ip_config
with
open
(
ip_config
)
as
f
:
with
open
(
ip_config
)
as
f
:
for
line
in
f
:
for
line
in
f
:
...
@@ -34,11 +37,20 @@ def submit_jobs(args, udf_command):
...
@@ -34,11 +37,20 @@ def submit_jobs(args, udf_command):
count
=
int
(
count
)
count
=
int
(
count
)
server_count_per_machine
=
count
server_count_per_machine
=
count
hosts
.
append
((
ip
,
port
))
hosts
.
append
((
ip
,
port
))
assert
args
.
num_client
%
len
(
hosts
)
==
0
client_count_per_machine
=
int
(
args
.
num_client
/
len
(
hosts
))
# Get partition info of the graph data
part_config
=
args
.
workspace
+
'/'
+
args
.
part_config
with
open
(
part_config
)
as
conf_f
:
part_metadata
=
json
.
load
(
conf_f
)
assert
'num_parts'
in
part_metadata
,
'num_parts does not exist.'
# The number of partitions must match the number of machines in the cluster.
assert
part_metadata
[
'num_parts'
]
==
len
(
hosts
),
\
'The number of graph partitions has to match the number of machines in the cluster.'
tot_num_clients
=
args
.
num_trainers
*
(
1
+
args
.
num_samplers
)
*
len
(
hosts
)
# launch server tasks
# launch server tasks
server_cmd
=
'DGL_ROLE=server'
server_cmd
=
'DGL_ROLE=server'
server_cmd
=
server_cmd
+
' '
+
'DGL_NUM_CLIENT='
+
str
(
args
.
num_client
)
server_cmd
=
server_cmd
+
' '
+
'DGL_NUM_CLIENT='
+
str
(
tot_
num_client
s
)
server_cmd
=
server_cmd
+
' '
+
'DGL_CONF_PATH='
+
str
(
args
.
part_config
)
server_cmd
=
server_cmd
+
' '
+
'DGL_CONF_PATH='
+
str
(
args
.
part_config
)
server_cmd
=
server_cmd
+
' '
+
'DGL_IP_CONFIG='
+
str
(
args
.
ip_config
)
server_cmd
=
server_cmd
+
' '
+
'DGL_IP_CONFIG='
+
str
(
args
.
ip_config
)
for
i
in
range
(
len
(
hosts
)
*
server_count_per_machine
):
for
i
in
range
(
len
(
hosts
)
*
server_count_per_machine
):
...
@@ -49,7 +61,7 @@ def submit_jobs(args, udf_command):
...
@@ -49,7 +61,7 @@ def submit_jobs(args, udf_command):
execute_remote
(
cmd
,
ip
,
thread_list
)
execute_remote
(
cmd
,
ip
,
thread_list
)
# launch client tasks
# launch client tasks
client_cmd
=
'DGL_DIST_MODE="distributed" DGL_ROLE=client'
client_cmd
=
'DGL_DIST_MODE="distributed" DGL_ROLE=client'
client_cmd
=
client_cmd
+
' '
+
'DGL_NUM_CLIENT='
+
str
(
args
.
num_client
)
client_cmd
=
client_cmd
+
' '
+
'DGL_NUM_CLIENT='
+
str
(
tot_
num_client
s
)
client_cmd
=
client_cmd
+
' '
+
'DGL_CONF_PATH='
+
str
(
args
.
part_config
)
client_cmd
=
client_cmd
+
' '
+
'DGL_CONF_PATH='
+
str
(
args
.
part_config
)
client_cmd
=
client_cmd
+
' '
+
'DGL_IP_CONFIG='
+
str
(
args
.
ip_config
)
client_cmd
=
client_cmd
+
' '
+
'DGL_IP_CONFIG='
+
str
(
args
.
ip_config
)
if
os
.
environ
.
get
(
'OMP_NUM_THREADS'
)
is
not
None
:
if
os
.
environ
.
get
(
'OMP_NUM_THREADS'
)
is
not
None
:
...
@@ -58,7 +70,7 @@ def submit_jobs(args, udf_command):
...
@@ -58,7 +70,7 @@ def submit_jobs(args, udf_command):
client_cmd
=
client_cmd
+
' '
+
'PYTHONPATH='
+
os
.
environ
.
get
(
'PYTHONPATH'
)
client_cmd
=
client_cmd
+
' '
+
'PYTHONPATH='
+
os
.
environ
.
get
(
'PYTHONPATH'
)
torch_cmd
=
'-m torch.distributed.launch'
torch_cmd
=
'-m torch.distributed.launch'
torch_cmd
=
torch_cmd
+
' '
+
'--nproc_per_node='
+
str
(
client_count_per_mach
ine
)
torch_cmd
=
torch_cmd
+
' '
+
'--nproc_per_node='
+
str
(
args
.
num_tra
ine
rs
)
torch_cmd
=
torch_cmd
+
' '
+
'--nnodes='
+
str
(
len
(
hosts
))
torch_cmd
=
torch_cmd
+
' '
+
'--nnodes='
+
str
(
len
(
hosts
))
torch_cmd
=
torch_cmd
+
' '
+
'--node_rank='
+
str
(
0
)
torch_cmd
=
torch_cmd
+
' '
+
'--node_rank='
+
str
(
0
)
torch_cmd
=
torch_cmd
+
' '
+
'--master_addr='
+
str
(
hosts
[
0
][
0
])
torch_cmd
=
torch_cmd
+
' '
+
'--master_addr='
+
str
(
hosts
[
0
][
0
])
...
@@ -85,15 +97,18 @@ def main():
...
@@ -85,15 +97,18 @@ def main():
help
=
'Path of user directory of distributed tasks.
\
help
=
'Path of user directory of distributed tasks.
\
This is used to specify a destination location where
\
This is used to specify a destination location where
\
the contents of current directory will be rsyncd'
)
the contents of current directory will be rsyncd'
)
parser
.
add_argument
(
'--num_client'
,
type
=
int
,
parser
.
add_argument
(
'--num_trainers'
,
type
=
int
,
help
=
'Total number of client processes in the cluster'
)
help
=
'The number of trainer processes per machine'
)
parser
.
add_argument
(
'--num_samplers'
,
type
=
int
,
help
=
'The number of sampler processes per trainer process'
)
parser
.
add_argument
(
'--part_config'
,
type
=
str
,
parser
.
add_argument
(
'--part_config'
,
type
=
str
,
help
=
'The file (in workspace) of the partition config'
)
help
=
'The file (in workspace) of the partition config'
)
parser
.
add_argument
(
'--ip_config'
,
type
=
str
,
parser
.
add_argument
(
'--ip_config'
,
type
=
str
,
help
=
'The file (in workspace) of IP configuration for server processes'
)
help
=
'The file (in workspace) of IP configuration for server processes'
)
args
,
udf_command
=
parser
.
parse_known_args
()
args
,
udf_command
=
parser
.
parse_known_args
()
assert
len
(
udf_command
)
==
1
,
'Please provide user command line.'
assert
len
(
udf_command
)
==
1
,
'Please provide user command line.'
assert
args
.
num_client
>
0
,
'--num_client must be a positive number.'
assert
args
.
num_trainers
>
0
,
'--num_trainers must be a positive number.'
assert
args
.
num_samplers
>=
0
udf_command
=
str
(
udf_command
[
0
])
udf_command
=
str
(
udf_command
[
0
])
if
'python'
not
in
udf_command
:
if
'python'
not
in
udf_command
:
raise
RuntimeError
(
"DGL launching script can only support Python executable file."
)
raise
RuntimeError
(
"DGL launching script can only support Python executable file."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment