Unverified commit f0fbbc16, authored by Da Zheng and committed by GitHub
Browse files

[Distributed] Fix the launch script. (#1977)

* update launch script

* check the correctness of launch script.

* fix.
parent 4b8eaf20
...@@ -7,6 +7,7 @@ import argparse ...@@ -7,6 +7,7 @@ import argparse
import signal import signal
import logging import logging
import time import time
import json
from threading import Thread from threading import Thread
def execute_remote(cmd, ip, thread_list): def execute_remote(cmd, ip, thread_list):
...@@ -26,6 +27,8 @@ def submit_jobs(args, udf_command): ...@@ -26,6 +27,8 @@ def submit_jobs(args, udf_command):
hosts = [] hosts = []
thread_list = [] thread_list = []
server_count_per_machine = 0 server_count_per_machine = 0
# Get the IP addresses of the cluster.
ip_config = args.workspace + '/' + args.ip_config ip_config = args.workspace + '/' + args.ip_config
with open(ip_config) as f: with open(ip_config) as f:
for line in f: for line in f:
...@@ -34,11 +37,20 @@ def submit_jobs(args, udf_command): ...@@ -34,11 +37,20 @@ def submit_jobs(args, udf_command):
count = int(count) count = int(count)
server_count_per_machine = count server_count_per_machine = count
hosts.append((ip, port)) hosts.append((ip, port))
assert args.num_client % len(hosts) == 0
client_count_per_machine = int(args.num_client / len(hosts)) # Get partition info of the graph data
part_config = args.workspace + '/' + args.part_config
with open(part_config) as conf_f:
part_metadata = json.load(conf_f)
assert 'num_parts' in part_metadata, 'num_parts does not exist.'
# The number of partitions must match the number of machines in the cluster.
assert part_metadata['num_parts'] == len(hosts), \
'The number of graph partitions has to match the number of machines in the cluster.'
tot_num_clients = args.num_trainers * (1 + args.num_samplers) * len(hosts)
# launch server tasks # launch server tasks
server_cmd = 'DGL_ROLE=server' server_cmd = 'DGL_ROLE=server'
server_cmd = server_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(args.num_client) server_cmd = server_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(tot_num_clients)
server_cmd = server_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.part_config) server_cmd = server_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.part_config)
server_cmd = server_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config) server_cmd = server_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config)
for i in range(len(hosts)*server_count_per_machine): for i in range(len(hosts)*server_count_per_machine):
...@@ -49,7 +61,7 @@ def submit_jobs(args, udf_command): ...@@ -49,7 +61,7 @@ def submit_jobs(args, udf_command):
execute_remote(cmd, ip, thread_list) execute_remote(cmd, ip, thread_list)
# launch client tasks # launch client tasks
client_cmd = 'DGL_DIST_MODE="distributed" DGL_ROLE=client' client_cmd = 'DGL_DIST_MODE="distributed" DGL_ROLE=client'
client_cmd = client_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(args.num_client) client_cmd = client_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(tot_num_clients)
client_cmd = client_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.part_config) client_cmd = client_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.part_config)
client_cmd = client_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config) client_cmd = client_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config)
if os.environ.get('OMP_NUM_THREADS') is not None: if os.environ.get('OMP_NUM_THREADS') is not None:
...@@ -58,7 +70,7 @@ def submit_jobs(args, udf_command): ...@@ -58,7 +70,7 @@ def submit_jobs(args, udf_command):
client_cmd = client_cmd + ' ' + 'PYTHONPATH=' + os.environ.get('PYTHONPATH') client_cmd = client_cmd + ' ' + 'PYTHONPATH=' + os.environ.get('PYTHONPATH')
torch_cmd = '-m torch.distributed.launch' torch_cmd = '-m torch.distributed.launch'
torch_cmd = torch_cmd + ' ' + '--nproc_per_node=' + str(client_count_per_machine) torch_cmd = torch_cmd + ' ' + '--nproc_per_node=' + str(args.num_trainers)
torch_cmd = torch_cmd + ' ' + '--nnodes=' + str(len(hosts)) torch_cmd = torch_cmd + ' ' + '--nnodes=' + str(len(hosts))
torch_cmd = torch_cmd + ' ' + '--node_rank=' + str(0) torch_cmd = torch_cmd + ' ' + '--node_rank=' + str(0)
torch_cmd = torch_cmd + ' ' + '--master_addr=' + str(hosts[0][0]) torch_cmd = torch_cmd + ' ' + '--master_addr=' + str(hosts[0][0])
...@@ -85,15 +97,18 @@ def main(): ...@@ -85,15 +97,18 @@ def main():
help='Path of user directory of distributed tasks. \ help='Path of user directory of distributed tasks. \
This is used to specify a destination location where \ This is used to specify a destination location where \
the contents of current directory will be rsyncd') the contents of current directory will be rsyncd')
parser.add_argument('--num_client', type=int, parser.add_argument('--num_trainers', type=int,
help='Total number of client processes in the cluster') help='The number of trainer processes per machine')
parser.add_argument('--num_samplers', type=int,
help='The number of sampler processes per trainer process')
parser.add_argument('--part_config', type=str, parser.add_argument('--part_config', type=str,
help='The file (in workspace) of the partition config') help='The file (in workspace) of the partition config')
parser.add_argument('--ip_config', type=str, parser.add_argument('--ip_config', type=str,
help='The file (in workspace) of IP configuration for server processes') help='The file (in workspace) of IP configuration for server processes')
args, udf_command = parser.parse_known_args() args, udf_command = parser.parse_known_args()
assert len(udf_command) == 1, 'Please provide user command line.' assert len(udf_command) == 1, 'Please provide user command line.'
assert args.num_client > 0, '--num_client must be a positive number.' assert args.num_trainers > 0, '--num_trainers must be a positive number.'
assert args.num_samplers >= 0
udf_command = str(udf_command[0]) udf_command = str(udf_command[0])
if 'python' not in udf_command: if 'python' not in udf_command:
raise RuntimeError("DGL launching script can only support Python executable file.") raise RuntimeError("DGL launching script can only support Python executable file.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment