Unverified commit 2d372e35, authored by Da Zheng and committed by GitHub
Browse files

[Distributed] Kill training jobs in distributed training (#2881)



* kill training jobs.

* update.

* fix.
Co-authored-by: Zheng <dzzhen@3c22fba32af5.ant.amazon.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-71-112.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-73-81.ec2.internal>
Co-authored-by: xiang song(charlie.song) <classicxsong@gmail.com>
parent 45ec21b0
......@@ -9,10 +9,70 @@ import logging
import time
import json
import multiprocessing
import re
from functools import partial
from threading import Thread
# Default port number. NOTE(review): no use of this constant is visible in this
# section of the file -- confirm what it is the default for.
DEFAULT_PORT = 30050
def cleanup_proc(get_all_remote_pids, conn):
    '''Wait for the launcher's verdict and, if needed, kill the remote training jobs.

    This function blocks on ``conn`` until a message arrives:

    * ``'exit'`` -- the launcher finished normally, so there is nothing to
      clean up and the process terminates through ``sys.exit(0)``.
    * anything else -- every process reported by ``get_all_remote_pids`` is
      killed with ``kill_process``.

    If the connection is closed before any message arrives (``EOFError``), no
    normal exit was ever reported, so this is handled like a cleanup request
    instead of crashing without cleaning up.

    Parameters
    ----------
    get_all_remote_pids : callable
        Zero-argument callable returning a dict that maps ``(ip, port)`` to the
        list of process IDs to kill on that machine.
    conn : connection object
        Receiving end of the pipe; only ``recv()`` is called on it.
    '''
    print('cleanup process runs')
    # This process should not handle SIGINT.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    try:
        data = conn.recv()
    except EOFError:
        # The other end went away without telling us it exited normally.
        data = 'cleanup'

    # If the launch process exits normally, this process doesn't need to do anything.
    if data == 'exit':
        sys.exit(0)

    # Otherwise, we need to ssh to each machine and kill the training jobs.
    remote_pids = get_all_remote_pids()
    for (ip, port), pids in remote_pids.items():
        kill_process(ip, port, pids)
    print('cleanup process exits')
def kill_process(ip, port, pids):
    '''ssh to a remote machine and kill the specified processes.

    Every process first receives a plain ``kill``. Afterwards the processes
    that ``get_killed_pids`` still reports are sent ``kill -9``; this check is
    repeated at most three times.

    Parameters
    ----------
    ip : str
        Address of the remote machine.
    port : int or str
        Port passed to ``ssh -p``.
    pids : list of int
        Process IDs to kill. The list itself is left unmodified.
    '''
    if not pids:
        # Nothing to kill: do not pay for ssh round trips.
        return

    curr_pid = os.getpid()
    ssh_prefix = 'ssh -o StrictHostKeyChecking=no -p ' + str(port) + ' ' + ip
    # If we kill child processes first, the parent process may create more again. This happens
    # to Python's process pool. After sorting, we always kill parent processes first
    # (this relies on parents having smaller PIDs than their children).
    # A sorted copy is used so the caller's list is not reordered.
    targets = sorted(pids)
    for pid in targets:
        assert curr_pid != pid
        print('kill process {} on {}:{}'.format(pid, ip, port), flush=True)
        kill_cmd = ssh_prefix + ' \'kill {}\''.format(pid)
        subprocess.run(kill_cmd, shell=True)

    # It's possible that some of the processes are not killed. Let's try again.
    survivors = targets
    for _ in range(3):
        survivors = get_killed_pids(ip, port, survivors)
        if not survivors:
            break
        survivors.sort()
        for pid in survivors:
            print('kill process {} on {}:{}'.format(pid, ip, port), flush=True)
            kill_cmd = ssh_prefix + ' \'kill -9 {}\''.format(pid)
            subprocess.run(kill_cmd, shell=True)
def get_killed_pids(ip, port, killed_pids):
    '''Get the process IDs that we want to kill but are still alive.

    Runs ``ps -p <pids> -h`` on the remote machine through ssh and collects the
    first column of every output line.

    Parameters
    ----------
    ip : str
        Address of the remote machine.
    port : int or str
        Port passed to ``ssh -p``.
    killed_pids : list of int
        Process IDs that were already asked to terminate.

    Returns
    -------
    list of int
        The process IDs listed by the remote ``ps``. Empty when ``killed_pids``
        is empty, in which case no command is run at all.
    '''
    if not killed_pids:
        # An empty PID list would produce a malformed 'ps -p  -h' command.
        return []

    pid_list = ','.join(str(pid) for pid in killed_pids)
    ps_cmd = 'ssh -o StrictHostKeyChecking=no -p ' + str(port) + ' ' + ip + ' \'ps -p {} -h\''.format(pid_list)
    res = subprocess.run(ps_cmd, shell=True, stdout=subprocess.PIPE)
    alive = []
    for line in res.stdout.decode('utf-8').split('\n'):
        fields = line.split()
        # Skip blank lines and any line whose first column is not a number
        # (for instance a header line, should the remote ps print one).
        if fields and fields[0].isdecimal():
            alive.append(int(fields[0]))
    return alive
def execute_remote(cmd, ip, port, thread_list):
"""execute command line on remote machine via ssh"""
cmd = 'ssh -o StrictHostKeyChecking=no -p ' + str(port) + ' ' + ip + ' \'' + cmd + '\''
......@@ -25,6 +85,49 @@ def execute_remote(cmd, ip, port, thread_list):
thread.start()
thread_list.append(thread)
def get_remote_pids(ip, port, cmd_regex):
    """Get the process IDs that run the command in the remote machine.

    The remote ``ps -aux`` output is filtered to python processes whose line
    matches ``cmd_regex``; the direct children of those processes (found with
    ``pgrep -P``) are added to the result.

    Parameters
    ----------
    ip : str
        Address of the remote machine.
    port : int or str
        Port passed to ``ssh -p``.
    cmd_regex : str
        Regular expression searched for in every ``ps`` output line.

    Returns
    -------
    list of int
        Sorted, de-duplicated process IDs of the matching processes and their
        children. Any PID equal to the PID of the current local process is
        left out. Empty when nothing matches.
    """
    curr_pid = os.getpid()
    ssh_prefix = 'ssh -o StrictHostKeyChecking=no -p ' + str(port) + ' ' + ip
    # Here we want to get the python processes. We may get some ssh processes, so we should filter them out.
    ps_cmd = ssh_prefix + ' \'ps -aux | grep python | grep -v StrictHostKeyChecking\''
    ps_res = subprocess.run(ps_cmd, shell=True, stdout=subprocess.PIPE)

    parent_pids = []
    for line in ps_res.stdout.decode('utf-8').split('\n'):
        fields = line.split()
        if len(fields) < 2:
            continue
        # We only get the processes that run the specified command.
        # The second column of 'ps -aux' output is the PID.
        if re.search(cmd_regex, line) is not None and int(fields[1]) != curr_pid:
            parent_pids.append(fields[1])

    if not parent_pids:
        # No match: 'pgrep -P' with an empty list would be a malformed command.
        return []

    pgrep_cmd = ssh_prefix + ' \'pgrep -P {}\''.format(','.join(parent_pids))
    pgrep_res = subprocess.run(pgrep_cmd, shell=True, stdout=subprocess.PIPE)
    # split() without arguments drops blank and whitespace-only entries.
    child_pids = pgrep_res.stdout.decode('utf-8').split()

    all_pids = {int(pid) for pid in parent_pids}
    all_pids.update(int(pid) for pid in child_pids)
    all_pids.discard(curr_pid)
    return sorted(all_pids)
def get_all_remote_pids(hosts, ssh_port, udf_command):
    '''Get all remote processes.

    Parameters
    ----------
    hosts : iterable of 2-tuples
        One entry per machine; the first element is the machine's IP and the
        second element is ignored here.
    ssh_port : int or str
        Port of the ssh server, used for every machine.
    udf_command : str
        The user command whose processes should be found.

    Returns
    -------
    dict
        Maps ``(ip, ssh_port)`` to the list of process IDs returned by
        ``get_remote_pids`` for that machine.
    '''
    # When creating training processes in remote machines, we may insert some arguments
    # in the commands. We need to use regular expressions to match the modified command.
    # Every token is escaped so that characters such as '.', '+' or '(' in the
    # user command are matched literally instead of being read as regex syntax.
    # The pattern does not depend on the host, so it is built once.
    cmd_regex = ' .*'.join(re.escape(token) for token in udf_command.split())

    remote_pids = {}
    for ip, _ in hosts:
        remote_pids[(ip, ssh_port)] = get_remote_pids(ip, ssh_port, cmd_regex)
    return remote_pids
def submit_jobs(args, udf_command):
"""Submit distributed jobs (server and client processes) via ssh"""
hosts = []
......@@ -102,8 +205,25 @@ def submit_jobs(args, udf_command):
cmd = 'cd ' + str(args.workspace) + '; ' + cmd
execute_remote(cmd, ip, args.ssh_port, thread_list)
# Start a cleanup process dedicated for cleaning up remote training jobs.
conn1,conn2 = multiprocessing.Pipe()
func = partial(get_all_remote_pids, hosts, args.ssh_port, udf_command)
process = multiprocessing.Process(target=cleanup_proc, args=(func, conn1))
process.start()
def signal_handler(signal, frame):
logging.info('Stop launcher')
# We need to tell the cleanup process to kill remote training jobs.
conn2.send('cleanup')
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
for thread in thread_list:
thread.join()
# The training processes complete. We should tell the cleanup process to exit.
conn2.send('exit')
process.join()
def main():
parser = argparse.ArgumentParser(description='Launch a distributed job')
......@@ -153,12 +273,7 @@ def main():
raise RuntimeError("DGL launching script can only support Python executable file.")
submit_jobs(args, udf_command)
def signal_handler(signal, frame):
    """Log that the launcher is stopping, then exit with status 0.

    Both parameters are unused. Note that the ``signal`` parameter hides the
    ``signal`` module inside this function.
    """
    logging.info('Stop launcher')
    raise SystemExit(0)
if __name__ == '__main__':
    # Script entry point: set up INFO-level logging with a timestamped format,
    # register signal_handler for SIGINT, then run main().
    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO)
    signal.signal(signal.SIGINT, signal_handler)
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment