Unverified Commit efd0fc6c authored by Eric Kim's avatar Eric Kim Committed by GitHub
Browse files

[Tools] Adds --ssh_username to the dist launcher (tools/launch.py) (#3202)

parent 7ab659e2
......@@ -12,6 +12,7 @@ import multiprocessing
import re
from functools import partial
from threading import Thread
from typing import Optional
DEFAULT_PORT = 30050
......@@ -73,17 +74,46 @@ def get_killed_pids(ip, port, killed_pids):
pids.append(int(l[0]))
return pids
def execute_remote(cmd, ip, port, thread_list):
"""execute command line on remote machine via ssh"""
cmd = 'ssh -o StrictHostKeyChecking=no -p ' + str(port) + ' ' + ip + ' \'' + cmd + '\''
def execute_remote(
cmd: str,
ip: str,
port: int,
username: Optional[str] = ""
) -> Thread:
"""Execute command line on remote machine via ssh.
Args:
cmd: User-defined command (udf) to execute on the remote host.
ip: The ip-address of the host to run the command on.
port: Port number that the host is listening on.
thread_list:
username: Optional. If given, this will specify a username to use when issuing commands over SSH.
Useful when your infra requires you to explicitly specify a username to avoid permission issues.
Returns:
thread: The Thread whose run() is to run the `cmd` on the remote host. Returns when the cmd completes
on the remote host.
"""
ip_prefix = ""
if username:
ip_prefix += "{username}@".format(username=username)
# Construct ssh command that executes `cmd` on the remote host
ssh_cmd = "ssh -o StrictHostKeyChecking=no -p {port} {ip_prefix}{ip} '{cmd}'".format(
port=str(port),
ip_prefix=ip_prefix,
ip=ip,
cmd=cmd,
)
# thread func to run the job
def run(cmd):
subprocess.check_call(cmd, shell = True)
def run(ssh_cmd):
subprocess.check_call(ssh_cmd, shell=True)
thread = Thread(target = run, args=(cmd,))
thread = Thread(target=run, args=(ssh_cmd,))
thread.setDaemon(True)
thread.start()
thread_list.append(thread)
return thread
def get_remote_pids(ip, port, cmd_regex):
"""Get the process IDs that run the command in the remote machine.
......@@ -173,7 +203,7 @@ def submit_jobs(args, udf_command):
cmd = server_cmd + ' ' + 'DGL_SERVER_ID=' + str(i)
cmd = cmd + ' ' + udf_command
cmd = 'cd ' + str(args.workspace) + '; ' + cmd
execute_remote(cmd, ip, args.ssh_port, thread_list)
thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))
# launch client tasks
client_cmd = 'DGL_DIST_MODE="distributed" DGL_ROLE=client DGL_NUM_SAMPLER=' + str(args.num_samplers)
......@@ -206,7 +236,7 @@ def submit_jobs(args, udf_command):
new_udf_command = udf_command.replace('python', 'python ' + new_torch_cmd)
cmd = client_cmd + ' ' + new_udf_command
cmd = 'cd ' + str(args.workspace) + '; ' + cmd
execute_remote(cmd, ip, args.ssh_port, thread_list)
thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))
# Start a cleanup process dedicated for cleaning up remote training jobs.
conn1,conn2 = multiprocessing.Pipe()
......@@ -231,6 +261,12 @@ def submit_jobs(args, udf_command):
def main():
parser = argparse.ArgumentParser(description='Launch a distributed job')
parser.add_argument('--ssh_port', type=int, default=22, help='SSH Port.')
parser.add_argument(
"--ssh_username", default="",
help="Optional. When issuing commands (via ssh) to cluster, use the provided username in the ssh cmd. "
"Example: If you provide --ssh_username=bob, then the ssh command will be like: 'ssh bob@1.2.3.4 CMD' "
"instead of 'ssh 1.2.3.4 CMD'"
)
parser.add_argument('--workspace', type=str,
help='Path of user directory of distributed tasks. \
This is used to specify a destination location where \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment