Unverified Commit efd0fc6c authored by Eric Kim's avatar Eric Kim Committed by GitHub
Browse files

[Tools] Adds --ssh_username to the dist launcher (tools/launch.py) (#3202)

parent 7ab659e2
...@@ -12,6 +12,7 @@ import multiprocessing ...@@ -12,6 +12,7 @@ import multiprocessing
import re import re
from functools import partial from functools import partial
from threading import Thread from threading import Thread
from typing import Optional
DEFAULT_PORT = 30050 DEFAULT_PORT = 30050
...@@ -73,17 +74,46 @@ def get_killed_pids(ip, port, killed_pids): ...@@ -73,17 +74,46 @@ def get_killed_pids(ip, port, killed_pids):
pids.append(int(l[0])) pids.append(int(l[0]))
return pids return pids
def execute_remote(cmd, ip, port, thread_list): def execute_remote(
"""execute command line on remote machine via ssh""" cmd: str,
cmd = 'ssh -o StrictHostKeyChecking=no -p ' + str(port) + ' ' + ip + ' \'' + cmd + '\'' ip: str,
port: int,
username: Optional[str] = ""
) -> Thread:
"""Execute command line on remote machine via ssh.
Args:
cmd: User-defined command (udf) to execute on the remote host.
ip: The ip-address of the host to run the command on.
port: Port number that the host is listening on.
thread_list:
username: Optional. If given, this will specify a username to use when issuing commands over SSH.
Useful when your infra requires you to explicitly specify a username to avoid permission issues.
Returns:
thread: The Thread whose run() is to run the `cmd` on the remote host. Returns when the cmd completes
on the remote host.
"""
ip_prefix = ""
if username:
ip_prefix += "{username}@".format(username=username)
# Construct ssh command that executes `cmd` on the remote host
ssh_cmd = "ssh -o StrictHostKeyChecking=no -p {port} {ip_prefix}{ip} '{cmd}'".format(
port=str(port),
ip_prefix=ip_prefix,
ip=ip,
cmd=cmd,
)
# thread func to run the job # thread func to run the job
def run(cmd): def run(ssh_cmd):
subprocess.check_call(cmd, shell = True) subprocess.check_call(ssh_cmd, shell=True)
thread = Thread(target = run, args=(cmd,)) thread = Thread(target=run, args=(ssh_cmd,))
thread.setDaemon(True) thread.setDaemon(True)
thread.start() thread.start()
thread_list.append(thread) return thread
def get_remote_pids(ip, port, cmd_regex): def get_remote_pids(ip, port, cmd_regex):
"""Get the process IDs that run the command in the remote machine. """Get the process IDs that run the command in the remote machine.
...@@ -173,7 +203,7 @@ def submit_jobs(args, udf_command): ...@@ -173,7 +203,7 @@ def submit_jobs(args, udf_command):
cmd = server_cmd + ' ' + 'DGL_SERVER_ID=' + str(i) cmd = server_cmd + ' ' + 'DGL_SERVER_ID=' + str(i)
cmd = cmd + ' ' + udf_command cmd = cmd + ' ' + udf_command
cmd = 'cd ' + str(args.workspace) + '; ' + cmd cmd = 'cd ' + str(args.workspace) + '; ' + cmd
execute_remote(cmd, ip, args.ssh_port, thread_list) thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))
# launch client tasks # launch client tasks
client_cmd = 'DGL_DIST_MODE="distributed" DGL_ROLE=client DGL_NUM_SAMPLER=' + str(args.num_samplers) client_cmd = 'DGL_DIST_MODE="distributed" DGL_ROLE=client DGL_NUM_SAMPLER=' + str(args.num_samplers)
...@@ -206,7 +236,7 @@ def submit_jobs(args, udf_command): ...@@ -206,7 +236,7 @@ def submit_jobs(args, udf_command):
new_udf_command = udf_command.replace('python', 'python ' + new_torch_cmd) new_udf_command = udf_command.replace('python', 'python ' + new_torch_cmd)
cmd = client_cmd + ' ' + new_udf_command cmd = client_cmd + ' ' + new_udf_command
cmd = 'cd ' + str(args.workspace) + '; ' + cmd cmd = 'cd ' + str(args.workspace) + '; ' + cmd
execute_remote(cmd, ip, args.ssh_port, thread_list) thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))
# Start a cleanup process dedicated for cleaning up remote training jobs. # Start a cleanup process dedicated for cleaning up remote training jobs.
conn1,conn2 = multiprocessing.Pipe() conn1,conn2 = multiprocessing.Pipe()
...@@ -231,6 +261,12 @@ def submit_jobs(args, udf_command): ...@@ -231,6 +261,12 @@ def submit_jobs(args, udf_command):
def main(): def main():
parser = argparse.ArgumentParser(description='Launch a distributed job') parser = argparse.ArgumentParser(description='Launch a distributed job')
parser.add_argument('--ssh_port', type=int, default=22, help='SSH Port.') parser.add_argument('--ssh_port', type=int, default=22, help='SSH Port.')
parser.add_argument(
"--ssh_username", default="",
help="Optional. When issuing commands (via ssh) to cluster, use the provided username in the ssh cmd. "
"Example: If you provide --ssh_username=bob, then the ssh command will be like: 'ssh bob@1.2.3.4 CMD' "
"instead of 'ssh 1.2.3.4 CMD'"
)
parser.add_argument('--workspace', type=str, parser.add_argument('--workspace', type=str,
help='Path of user directory of distributed tasks. \ help='Path of user directory of distributed tasks. \
This is used to specify a destination location where \ This is used to specify a destination location where \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment