Unverified Commit 7e2ed9f8 authored by Minjie Wang, committed by GitHub

Merge branch 'master' into dist_part

parents 2cf4bd0a ee672c0b
@@ -9,6 +9,7 @@ import logging
 import time
 import json
 import multiprocessing
+import queue
 import re
 from functools import partial
 from threading import Thread
@@ -74,6 +75,7 @@ def get_killed_pids(ip, port, killed_pids):
 def execute_remote(
     cmd: str,
+    state_q: queue.Queue,
     ip: str,
     port: int,
     username: Optional[str] = ""
@@ -82,6 +84,7 @@ def execute_remote(
     Args:
         cmd: User-defined command (udf) to execute on the remote host.
+        state_q: A queue collecting Thread exit states.
         ip: The ip-address of the host to run the command on.
         port: Port number that the host is listening on.
        thread_list:
@@ -105,10 +108,17 @@
     )
     # thread func to run the job
-    def run(ssh_cmd):
-        subprocess.check_call(ssh_cmd, shell=True)
-    thread = Thread(target=run, args=(ssh_cmd,))
+    def run(ssh_cmd, state_q):
+        try:
+            subprocess.check_call(ssh_cmd, shell=True)
+            state_q.put(0)
+        except subprocess.CalledProcessError as err:
+            print(f"Called process error {err}")
+            state_q.put(err.returncode)
+        except Exception:
+            state_q.put(-1)
+    thread = Thread(target=run, args=(ssh_cmd, state_q,))
     thread.setDaemon(True)
     thread.start()
     # sleep for a while in case of ssh is rejected by peer due to busy connection
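For context, the following is a minimal, self-contained sketch of the pattern this hunk introduces: a daemon thread runs a shell command and reports its exit status through a shared `queue.Queue`, so the caller can later tell success from failure. The name `run_and_report` and the `exit 3` demo command are illustrative only and are not part of the launcher.

```python
import queue
import subprocess
from threading import Thread

def run_and_report(cmd, state_q):
    """Run `cmd` in a shell and push its exit status onto `state_q`."""
    try:
        subprocess.check_call(cmd, shell=True)
        state_q.put(0)                        # success
    except subprocess.CalledProcessError as err:
        state_q.put(err.returncode)           # command exited with a non-zero status
    except Exception:
        state_q.put(-1)                       # anything else, e.g. the shell could not be started

if __name__ == "__main__":
    state_q = queue.Queue()
    thread = Thread(target=run_and_report, args=("exit 3", state_q), daemon=True)
    thread.start()
    thread.join()
    print("exit status:", state_q.get())      # prints: exit status: 3
```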
@@ -535,6 +545,7 @@ def submit_jobs(args, udf_command, dry_run=False):
     assert part_metadata['num_parts'] == len(hosts), \
         'The number of graph partitions has to match the number of machines in the cluster.'
+    state_q = queue.Queue()
     tot_num_clients = args.num_trainers * (1 + args.num_samplers) * len(hosts)
     # launch server tasks
     if not has_alive_servers(args):
@@ -557,7 +568,7 @@ def submit_jobs(args, udf_command, dry_run=False):
             cmd = 'cd ' + str(args.workspace) + '; ' + cmd
             servers_cmd.append(cmd)
             if not dry_run:
-                thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))
+                thread_list.append(execute_remote(cmd, state_q, ip, args.ssh_port, username=args.ssh_username))
     else:
         print(f"Use running server {args.server_name}.")
@@ -592,7 +603,7 @@ def submit_jobs(args, udf_command, dry_run=False):
             cmd = 'cd ' + str(args.workspace) + '; ' + cmd
             clients_cmd.append(cmd)
             if not dry_run:
-                thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))
+                thread_list.append(execute_remote(cmd, state_q, ip, args.ssh_port, username=args.ssh_username))
     # return commands of clients/servers directly if in dry run mode
     if dry_run:
@@ -612,12 +623,21 @@
         sys.exit(0)
     signal.signal(signal.SIGINT, signal_handler)
+    err = 0
     for thread in thread_list:
         thread.join()
+        err_code = state_q.get()
+        if err_code != 0:
+            # Record err_code
+            # We record one of the errors if there are multiple.
+            err = err_code
     # The training processes complete. We should tell the cleanup process to exit.
     conn2.send('exit')
     process.join()
+    if err != 0:
+        print("Task failed")
+        sys.exit(-1)
 def main():
     parser = argparse.ArgumentParser(description='Launch a distributed job')
...
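Below is a standalone sketch of the join-and-collect step added in the last hunk: the main loop joins each worker thread, takes one status from the shared queue per thread, remembers a non-zero code if any thread failed, and exits with a failure status. `fake_worker` is a hypothetical stand-in for the threads created by `execute_remote`; it is not part of this change.

```python
import queue
import sys
from threading import Thread

def fake_worker(code, state_q):
    # Stand-in for execute_remote's thread: report an exit status via the queue.
    state_q.put(code)

if __name__ == "__main__":
    state_q = queue.Queue()
    thread_list = []
    for code in (0, 0, 1):                 # pretend one of three remote jobs failed
        t = Thread(target=fake_worker, args=(code, state_q), daemon=True)
        t.start()
        thread_list.append(t)

    err = 0
    for thread in thread_list:
        thread.join()
        err_code = state_q.get()
        if err_code != 0:
            err = err_code                 # keep one of the failures

    if err != 0:
        print("Task failed")
        sys.exit(-1)
    print("All jobs finished successfully")
```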