Unverified Commit ee672c0b authored by xiang song(charlie.song), committed by GitHub

Let distributed training launch script report error when any trainer or kvserver fails. (#4437)



* Collect error reports

* update

* fix
Co-authored-by: default avatarroot <root@ip-10-0-80-128.ec2.internal>
parent d077d371
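The mechanism this commit adds is small: every remote server/trainer command already runs in its own daemon thread, and the change makes each thread report an exit status into a shared `queue.Queue`, so that `submit_jobs` can fail the whole launch when any worker fails. Below is a minimal, self-contained sketch of that reporting pattern; `run_remote` and the `echo`/`false` commands are illustrative stand-ins for this example, not names from launch.py.

```python
import queue
import subprocess
import sys
from threading import Thread

def run_remote(cmd, state_q):
    """Run `cmd` in a daemon thread and report its exit status via `state_q`."""
    def run():
        try:
            subprocess.check_call(cmd, shell=True)
            state_q.put(0)                      # command succeeded
        except subprocess.CalledProcessError as err:
            state_q.put(err.returncode)         # command exited with a non-zero code
        except Exception:
            state_q.put(-1)                     # e.g. the command could not be launched
    thread = Thread(target=run, daemon=True)
    thread.start()
    return thread

if __name__ == "__main__":
    state_q = queue.Queue()
    # Stand-ins for the per-server and per-trainer ssh commands.
    cmds = ["echo server ok", "false", "echo trainer ok"]
    thread_list = [run_remote(cmd, state_q) for cmd in cmds]

    err = 0
    for thread in thread_list:
        thread.join()
        err_code = state_q.get()                # one status per launched command
        if err_code != 0:
            err = err_code                      # remember one of the failures
    if err != 0:
        print("Task failed")
        sys.exit(-1)
```

Because each thread puts exactly one status and the launcher performs one blocking `get()` per joined thread, the queue never deadlocks; if several workers fail, one of the non-zero codes is kept and turned into a non-zero launcher exit, matching the "record one of the errors" behavior in the diff below.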
@@ -9,6 +9,7 @@ import logging
 import time
 import json
 import multiprocessing
+import queue
 import re
 from functools import partial
 from threading import Thread
@@ -74,6 +75,7 @@ def get_killed_pids(ip, port, killed_pids):
 def execute_remote(
     cmd: str,
+    state_q: queue.Queue,
     ip: str,
     port: int,
     username: Optional[str] = ""
@@ -82,6 +84,7 @@ def execute_remote(
     Args:
         cmd: User-defined command (udf) to execute on the remote host.
+        state_q: A queue collecting Thread exit states.
        ip: The ip-address of the host to run the command on.
        port: Port number that the host is listening on.
        thread_list:
@@ -105,10 +108,17 @@
     )
     # thread func to run the job
-    def run(ssh_cmd):
-        subprocess.check_call(ssh_cmd, shell=True)
+    def run(ssh_cmd, state_q):
+        try:
+            subprocess.check_call(ssh_cmd, shell=True)
+            state_q.put(0)
+        except subprocess.CalledProcessError as err:
+            print(f"Called process error {err}")
+            state_q.put(err.returncode)
+        except Exception:
+            state_q.put(-1)

-    thread = Thread(target=run, args=(ssh_cmd,))
+    thread = Thread(target=run, args=(ssh_cmd, state_q,))
     thread.setDaemon(True)
     thread.start()
     # sleep for a while in case of ssh is rejected by peer due to busy connection
@@ -535,6 +545,7 @@ def submit_jobs(args, udf_command, dry_run=False):
     assert part_metadata['num_parts'] == len(hosts), \
         'The number of graph partitions has to match the number of machines in the cluster.'
+    state_q = queue.Queue()
     tot_num_clients = args.num_trainers * (1 + args.num_samplers) * len(hosts)
     # launch server tasks
     if not has_alive_servers(args):
@@ -557,7 +568,7 @@
             cmd = 'cd ' + str(args.workspace) + '; ' + cmd
             servers_cmd.append(cmd)
             if not dry_run:
-                thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))
+                thread_list.append(execute_remote(cmd, state_q, ip, args.ssh_port, username=args.ssh_username))
     else:
         print(f"Use running server {args.server_name}.")
@@ -592,7 +603,7 @@
         cmd = 'cd ' + str(args.workspace) + '; ' + cmd
         clients_cmd.append(cmd)
         if not dry_run:
-            thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))
+            thread_list.append(execute_remote(cmd, state_q, ip, args.ssh_port, username=args.ssh_username))

     # return commands of clients/servers directly if in dry run mode
     if dry_run:
@@ -612,12 +623,21 @@
         sys.exit(0)
     signal.signal(signal.SIGINT, signal_handler)
+    err = 0
     for thread in thread_list:
         thread.join()
+        err_code = state_q.get()
+        if err_code != 0:
+            # Record err_code
+            # We record one of the errors if there are multiple
+            err = err_code
     # The training processes complete. We should tell the cleanup process to exit.
     conn2.send('exit')
     process.join()
+    if err != 0:
+        print("Task failed")
+        sys.exit(-1)

 def main():
     parser = argparse.ArgumentParser(description='Launch a distributed job')