Unverified Commit cf6d1c92 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[CLI] refactored the launch CLI and fixed bugs in multi-node launching (#844)

* [cli] fixed multi-node job launching

* [cli] fixed a bug in version comparison

* [cli] support launching with env var

* [cli] fixed multi-node job launching

* [cli] fixed a bug in version comparison

* [cli] support launching with env var

* added docstring

* [cli] added extra launch arguments

* [cli] added default launch rdzv args

* [cli] fixed version comparison

* [cli] added docstring examples and requirement

* polish docstring

* polish code

* polish code
parent e5ea3fde
...@@ -5,27 +5,34 @@ from colossalai.context import Config ...@@ -5,27 +5,34 @@ from colossalai.context import Config
@click.command(help="Launch distributed training on a single node or multiple nodes", @click.command(help="Launch distributed training on a single node or multiple nodes",
context_settings=dict(ignore_unknown_options=True)) context_settings=dict(ignore_unknown_options=True))
@click.option("-H", "-host", "--host", type=str, default=None, help="the list of machines to launch") @click.option("-H",
@click.option("--hostfile", "-host",
"--host",
type=str, type=str,
default=None, default=None,
help="Hostfile path that defines the device pool available to the job (e.g. worker-name:number of slots)") help="the list of hostnames to launch in the format <host1>,<host2>")
@click.option( @click.option(
"--include", "--hostfile",
type=str, type=str,
default=None, default=None,
help= help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname")
"Specify computing devices to use during execution. String format is NODE_SPEC@NODE_SPEC where NODE_SPEC=<worker-name>:<list-of-slots>" @click.option("--include",
) type=str,
default=None,
help="Specify computing devices to use during execution. String format is <host1>,<host2>,"
" only effective when used with --hostfile.")
@click.option( @click.option(
"--exclude", "--exclude",
type=str, type=str,
default=None, default=None,
help= help=
"Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include." "Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --includ,"
) " only effective when used with --hostfile.")
@click.option("--num_nodes", type=int, default=-1, help="Total number of worker nodes to use.") @click.option("--num_nodes",
@click.option("--nproc_per_node", type=int, default=-1, help="Number of GPUs to use on each node.") type=int,
default=-1,
help="Total number of worker nodes to use, only effective when used with --hostfile.")
@click.option("--nproc_per_node", type=int, default=None, help="Number of GPUs to use on each node.")
@click.option("--master_port", @click.option("--master_port",
type=int, type=int,
default=29500, default=29500,
...@@ -35,34 +42,43 @@ from colossalai.context import Config ...@@ -35,34 +42,43 @@ from colossalai.context import Config
default="127.0.0.1", default="127.0.0.1",
help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.") help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.")
@click.option( @click.option(
"--launcher", "--extra_launch_args",
type=click.Choice(['torch', 'openmpi', 'slurm'], case_sensitive=False), type=str,
default="torch", default=None,
help="(optional) choose launcher backend for multi-node training. Options currently include PDSH, OpenMPI, SLURM.") help=
@click.option("--launcher_args", "Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. "
type=str, "This will be converted to --arg1=1 --arg2=2 during execution")
default=None, @click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection")
help="(optional) pass launcher specific arguments as a single quoted argument.")
@click.argument("user_script", type=str) @click.argument("user_script", type=str)
@click.argument('user_args', nargs=-1) @click.argument('user_args', nargs=-1)
def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str, def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str,
master_port: int, launcher: str, launcher_args: str, user_script: str, user_args: str): master_port: int, extra_launch_args: str, ssh_port: int, user_script: str, user_args: str) -> None:
""" """
To launch multiple processes on a single node or multiple nodes via command line. To launch multiple processes on a single node or multiple nodes via command line.
Usage:: Usage::
# run on the current node with all available GPUs # run with 4 GPUs on the current node use default port 29500
colossalai run train.py colossalai run --nprocs_per_node 4 train.py
# run with only 2 GPUs on the current node # run with 2 GPUs on the current node at port 29550
colossalai run --nprocs_per_node 2 train.py colossalai run --nprocs_per_node 4 --master_port 29550 train.py
# run on two nodes # run on two nodes
colossalai run --host <host1>,<host2> train.py colossalai run --host <host1>,<host2> --master_addr host1 --nprocs_per_node 4 train.py
# run with hostfile # run with hostfile
colossalai run --hostfile <file_path> train.py colossalai run --hostfile <file_path> --master_addr <host> --nprocs_per_node 4 train.py
# run with hostfile with only included hosts
colossalai run --hostfile <file_path> --master_addr host1 --include host1,host2 --nprocs_per_node 4 train.py
# run with hostfile excluding the hosts selected
colossalai run --hostfile <file_path> --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py
""" """
if not user_script.endswith('.py'):
click.echo(f'Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help')
exit()
args_dict = locals() args_dict = locals()
args = Config(args_dict) args = Config(args_dict)
args.user_args = list(args.user_args) args.user_args = list(args.user_args)
......
from typing import List
import socket
class HostInfo:
    """
    A data class to store host connection-related data.

    Args:
        hostname (str): name or IP address of the host
        port (int): the port for ssh connection, None when not specified

    Attributes set by the constructor:
        hostname: the value passed in, unchanged
        port: the value passed in, unchanged
        is_local_host: result of ``is_host_localhost(hostname, port)``,
            evaluated once at construction time
    """

    def __init__(
        self,
        hostname: str,
        port: int = None,
    ):
        self.hostname = hostname
        self.port = port
        self.is_local_host = HostInfo.is_host_localhost(hostname, port)

    @staticmethod
    def is_host_localhost(hostname: str, port: int = None) -> bool:
        """
        Check if the host refers to the local machine.

        Args:
            hostname (str): name or IP address of the host
            port (int): the port for ssh connection, 22 is used when None

        Returns:
            bool: True if it is local, False otherwise

        Note:
            Errors raised by ``socket.getaddrinfo`` (e.g. for a name that
            cannot be resolved) are not caught here and reach the caller.
        """
        loopback_names = ("localhost", "127.0.0.1", "0.0.0.0")

        # Test the name exactly as given first: socket.getfqdn() may rewrite a
        # loopback name into whatever the resolver maps it to, in which case
        # the comparison below would miss it and fall through to DNS lookups.
        if hostname in loopback_names:
            return True

        if port is None:
            port = 22    # no port specified, lets just use the ssh port

        fqdn = socket.getfqdn(hostname)
        if fqdn in loopback_names:
            return True

        # The host is local when any address it resolves to is also an address
        # of this machine's own hostname. Only the address part of each
        # sockaddr (index 0) is compared.
        local_addrs = {info[4][0] for info in socket.getaddrinfo(socket.gethostname(), port)}
        return any(info[4][0] in local_addrs for info in socket.getaddrinfo(fqdn, port))

    def __str__(self):
        return f'hostname: {self.hostname}, port: {self.port}'

    def __repr__(self):
        return self.__str__()
class HostInfoList:
    """
    A data class to store a list of HostInfo objects.

    Hosts are kept in insertion order in ``hostinfo_list`` and are looked up
    by their ``hostname`` attribute.
    """

    def __init__(self):
        self.hostinfo_list = []

    def append(self, hostinfo: "HostInfo") -> None:
        """
        Add an HostInfo object to the list.

        Args:
            hostinfo (HostInfo): host information
        """
        self.hostinfo_list.append(hostinfo)

    def remove(self, hostname: str) -> None:
        """
        Remove the HostInfo object which matches with the hostname from the list.

        Args:
            hostname (str): the name of the host

        Raises:
            Exception: if no stored host has this hostname
        """
        hostinfo = self.get_hostinfo(hostname)
        self.hostinfo_list.remove(hostinfo)

    def get_hostinfo(self, hostname: str) -> "HostInfo":
        """
        Return the HostInfo object which matches with the hostname.

        Args:
            hostname (str): the name of the host

        Returns:
            hostinfo (HostInfo): the first stored HostInfo object whose hostname matches

        Raises:
            Exception: if no stored host has this hostname
        """
        for hostinfo in self.hostinfo_list:
            if hostinfo.hostname == hostname:
                return hostinfo
        raise Exception(f"Hostname {hostname} is not found")

    def has(self, hostname: str) -> bool:
        """
        Check if the hostname has been added.

        Args:
            hostname (str): the name of the host

        Returns:
            bool: True if added, False otherwise
        """
        return any(hostinfo.hostname == hostname for hostinfo in self.hostinfo_list)

    def __iter__(self):
        return iter(self.hostinfo_list)

    def __len__(self):
        return len(self.hostinfo_list)
import os import fabric
import sys from fabric import Connection
import shutil from .hostinfo import HostInfo, HostInfoList
from shlex import quote from multiprocessing import Pipe, Process
from abc import ABC, abstractmethod from multiprocessing import connection as mp_connection
import click
from colossalai.logging import get_dist_logger
def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection,
class MultiNodeRunner(ABC): send_conn: mp_connection.Connection, env: dict) -> None:
"""
def __init__(self, args): Use fabric connection to execute command on local or remote hosts.
self.args = args
self.user_arguments = self.args.user_args Args:
self.user_script = args.user_script hostinfo (HostInfo): host information
self.exports = {} workdir (str): the directory to execute the command
recv_conn (multiprocessing.connection.Connection): receive messages from the master sender
@abstractmethod send_conn (multiprocessing.connection.Connection): send messages to the master receiver
def backend_exists(self): env (dict): a dictionary for environment variables
"""Return whether the corresponding backend exists""" """
@abstractmethod fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port)
def get_cmd(self, environment, active_devices): finish = False
"""Return the command to execute on node""" env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()])
def add_export(self, key, var): # keep listening until exit
self.exports[key.strip()] = var.strip() while not finish:
# receive cmd
@property cmds = recv_conn.recv()
def name(self):
"""Return the name of the backend""" if cmds == 'exit':
return self.__class__.__name__ # exit from the loop
finish = True
break
class PDSHRunner(MultiNodeRunner): else:
# execute the commands
def __init__(self, args): try:
super().__init__(args) # cd to execute directory
with fab_conn.cd(workdir):
def backend_exists(self): # propagate the runtime environment
return shutil.which('pdsh') with fab_conn.prefix(f"export {env_msg}"):
if hostinfo.is_local_host:
@property # execute on the local machine
def name(self): fab_conn.local(cmds, hide=False)
return "pdsh" else:
# execute on the remote machine
def parse_user_args(self): fab_conn.run(cmds, hide=False)
return list(map(lambda x: x if x.startswith("-") else f"'{x}'", self.args.user_args)) send_conn.send('success')
except:
def get_cmd(self, environment, active_devices, args): click.echo(f"Error: failed to run {cmds} on {hostinfo.hostname}")
environment['PDSH_RCMD_TYPE'] = 'ssh' send_conn.send('failure')
active_workers = ",".join(active_devices.keys()) # shutdown
print("Running on the following workers: %s" % active_workers) send_conn.send("finish")
fab_conn.close()
pdsh_cmd_args = ['pdsh', '-f', str(1024), '-w', active_workers]
exports = "" class MultiNodeRunner:
for key, val in self.exports.items(): """
exports += f"export {key}={quote(val)}; " A runner to execute commands on an array of machines. This runner
is inspired by Nezha (https://github.com/zhuzilin/NeZha).
# https://linux.die.net/man/1/pdsh """
# %n will be replaced by pdsh command
colossal_launch = [ def __init__(self):
exports, f"cd {os.path.abspath('.')};", sys.executable, "-u", "-m", "torch.distributed.launch", self.processes = {}
f"--nproc_per_node={args.nproc_per_node}", f"--master_addr={args.master_addr}", self.master_send_conns = {}
f"--master_port={args.master_port}" self.master_recv_conns = {}
]
return pdsh_cmd_args + colossal_launch + [self.user_script] + self.user_arguments def connect(self, host_info_list: HostInfoList, workdir: str, env: dict) -> None:
"""
Establish connections to a list of hosts
Args:
host_info_list (HostInfoList): a list of HostInfo objects
workdir (str): the directory where command is executed
env (dict): environment variables to propagate to hosts
"""
for hostinfo in host_info_list:
master_send_conn, worker_recv_conn = Pipe()
master_recv_conn, worker_send_conn = Pipe()
p = Process(target=run_on_host, args=(hostinfo, workdir, worker_recv_conn, worker_send_conn, env))
p.start()
self.processes[hostinfo.hostname] = p
self.master_recv_conns[hostinfo.hostname] = master_recv_conn
self.master_send_conns[hostinfo.hostname] = master_send_conn
def send(self, hostinfo: "HostInfo", cmd: str) -> None:
    """
    Send a command to a local/remote host.

    The command is written to the connection stored for ``hostinfo.hostname``
    in ``self.master_send_conns``.

    Args:
        hostinfo (HostInfo): host information
        cmd (str): the command to execute

    Raises:
        AssertionError: if no connection is stored for ``hostinfo.hostname``
    """
    # Raise explicitly instead of using an ``assert`` statement so the check
    # still runs under ``python -O``; the exception type and message are
    # unchanged.
    if hostinfo.hostname not in self.master_send_conns:
        raise AssertionError(f'{hostinfo} is not found in the current connections')

    self.master_send_conns[hostinfo.hostname].send(cmd)
def stop_all(self) -> None:
    """
    Stop connections to all hosts.

    Sends the ``'exit'`` message on every connection stored in
    ``self.master_send_conns``. Nothing is sent when no connection is stored.
    """
    # only the connection objects are needed, the hostnames are not used
    for conn in self.master_send_conns.values():
        conn.send('exit')
def recv_from_all(self) -> dict:
    """
    Receive messages from all hosts

    Calls ``recv()`` once on every connection stored in
    ``self.master_recv_conns``, in the order the connections were stored.

    Returns:
        msg_from_node (dict): a dictionary which maps each hostname to the
            message received from that node
    """
    return {hostname: conn.recv() for hostname, conn in self.master_recv_conns.items()}
import click import click
import subprocess
import collections
import sys import sys
import os import os
import torch import torch
from colossalai.context import Config from colossalai.context import Config
from .multinode_runner import PDSHRunner from .multinode_runner import MultiNodeRunner
from copy import deepcopy from .hostinfo import HostInfo, HostInfoList
from typing import List
from packaging import version
# Constants that define our syntax
NODE_SEP = ','
def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
"""
Parse the hostfile to obtain a list of hosts.
A hostfile should look like:
worker-0
worker-1
worker-2
...
Args:
hostfile_path (str): the path to the hostfile
ssh_port (int): the port to connect to the host
"""
def fetch_hostfile(hostfile_path):
if not os.path.isfile(hostfile_path): if not os.path.isfile(hostfile_path):
click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}") click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}")
exit() exit()
# e.g., worker-0:16
with open(hostfile_path, 'r') as fd: with open(hostfile_path, 'r') as fd:
device_pool = collections.OrderedDict() device_pool = HostInfoList()
for line in fd.readlines(): for line in fd.readlines():
line = line.strip() line = line.strip()
if line == '': if line == '':
# skip empty lines # skip empty lines
continue continue
try:
hostname, slot_count = line.split(":")
slot_count = int(slot_count)
except ValueError as err:
click.echo(f"Error: Hostfile is not formatted correctly, expected <hostname>:<slot>, but found {line}")
exit()
if hostname in device_pool: # build the HostInfo object
hostname = line.strip()
hostinfo = HostInfo(hostname=hostname, port=ssh_port)
if device_pool.has(hostname):
click.echo(f"Error: found duplicate host {hostname} in the hostfile") click.echo(f"Error: found duplicate host {hostname} in the hostfile")
exit() exit()
device_pool[hostname] = slot_count
return device_pool
device_pool.append(hostinfo)
def _stable_remove_duplicates(data): return device_pool
# Create a new list in the same order as original but with duplicates
# removed, should never be more than ~16 elements so simple is best
new_list = []
for x in data:
if x not in new_list:
new_list.append(x)
return new_list
def parse_device_filter(host_info, include_str=None, exclude_str=None): def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList:
'''Parse an inclusion or exclusion string and filter a hostfile dictionary. '''Parse an inclusion or exclusion string and filter a hostfile dictionary.
Examples: Examples:
include_str="worker-0@worker-1:0,2" will use all slots on worker-0 and include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1.
slots [0, 2] on worker-1. exclude_str="worker-1" will use all available devices except worker-1.
exclude_str="worker-1:0" will use all available devices except
slot 0 on worker-1.
'''
# Constants that define our syntax Args:
NODE_SEP = '@' device_pool (HostInfoList): a list of HostInfo objects
SLOT_LIST_START = ':' include_str (str): --include option passed by user, default None
SLOT_SEP = ',' exclude_str (str): --exclude option passed by user, default None
Returns:
filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion
'''
# Ensure include/exclude are mutually exclusive # Ensure include/exclude are mutually exclusive
if include_str and exclude_str: if include_str and exclude_str:
...@@ -68,176 +75,207 @@ def parse_device_filter(host_info, include_str=None, exclude_str=None): ...@@ -68,176 +75,207 @@ def parse_device_filter(host_info, include_str=None, exclude_str=None):
# no-op # no-op
if include_str is None and exclude_str is None: if include_str is None and exclude_str is None:
return host_info return device_pool
# Either build from scratch or remove items # Either build from scratch or remove items
filtered_hosts = dict()
if include_str: if include_str:
parse_str = include_str parse_str = include_str
filtered_hosts = HostInfoList()
elif exclude_str: elif exclude_str:
filtered_hosts = deepcopy(host_info)
parse_str = exclude_str parse_str = exclude_str
filtered_hosts = device_pool
# foreach node in the list # foreach node in the list
for node_config in parse_str.split(NODE_SEP): for node_config in parse_str.split(NODE_SEP):
# Node can either be alone or node:slot,slot,slot hostname = node_config
if SLOT_LIST_START in node_config: hostinfo = device_pool.get_hostinfo(hostname)
hostname, slots = node_config.split(SLOT_LIST_START) # sanity check hostname
slots = [int(x) for x in slots.split(SLOT_SEP)] if not device_pool.has(hostname):
click.echo(f"Error: Hostname '{hostname}' not found in hostfile")
# sanity checks exit()
if hostname not in host_info:
click.echo(f"Hostname '{hostname}' not found in hostfile") if include_str:
exit() filtered_hosts.append(hostinfo)
for slot in slots: elif exclude_str:
if slot not in host_info[hostname]: filtered_hosts.remove(hostname)
click.echo(f"No slot '{slot}' specified on host '{hostname}'")
return filtered_hosts
# If include string, build the list from here
if include_str:
filtered_hosts[hostname] = slots def get_launch_command(
elif exclude_str: master_addr: str,
for slot in slots: master_port: int,
click.echo(f'- removing {slot} from {hostname}') nproc_per_node: int,
filtered_hosts[hostname].remove(slot) user_script: str,
user_args: List[str],
# User just specified the whole node node_rank: int,
else: num_nodes: int,
hostname = node_config extra_launch_args: str = None,
# sanity check hostname ) -> str:
if hostname not in host_info: """
click.echo(f"Hostname '{hostname}' not found in hostfile") Generate a command for distributed training.
exit()
Args:
if include_str: master_addr (str): the host of the master node
filtered_hosts[hostname] = host_info[hostname] master_port (str): the port of the master node
elif exclude_str: nproc_per_node (str): the number of processes to launch on each node
filtered_hosts[hostname] = [] user_script (str): the user Python file
user_args (str): the arguments for the user script
# Post-processing to remove duplicates and empty nodes node_rank (int): the unique ID for the node
del_keys = [] num_nodes (int): the number of nodes to execute jobs
for hostname in filtered_hosts:
# Remove duplicates Returns:
filtered_hosts[hostname] = _stable_remove_duplicates(filtered_hosts[hostname]) cmd (str): the command the start distributed training
# Remove empty hosts """
if len(filtered_hosts[hostname]) == 0:
del_keys.append(hostname)
# remove unneeded hosts
for name in del_keys:
del filtered_hosts[name]
# Lastly, go over filtered_hosts and convert to a OrderedDict() to ensure
# we map ranks to nodes correctly by maintaining host_info ordering.
ordered_hosts = collections.OrderedDict()
for host in host_info:
if host in filtered_hosts:
ordered_hosts[host] = filtered_hosts[host]
return ordered_hosts def _arg_dict_to_list(arg_dict):
ret = []
for k, v in arg_dict.items():
if v:
ret.append(f'--{k}={v}')
else:
ret.append(f'--{k}')
return ret
if extra_launch_args:
extra_launch_args_dict = dict()
for arg in extra_launch_args.split(','):
if '=' in arg:
k, v = arg.split('=')
extra_launch_args_dict[k] = v
else:
extra_launch_args_dict[arg] = None
extra_launch_args = extra_launch_args_dict
else:
extra_launch_args = dict()
torch_version = version.parse(torch.__version__)
assert torch_version.major == 1
def parse_inclusion_exclusion(device_pool, inclusion, exclusion): if torch_version.minor < 9:
active_devices = collections.OrderedDict() cmd = [
for hostname, slots in device_pool.items(): sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
active_devices[hostname] = list(range(slots)) f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}",
f"--node_rank={node_rank}"
]
else:
# extra launch args for torch distributed launcher with torch >= 1.9
default_torchrun_rdzv_args = dict(rdzv_backend="c10d",
rdzv_endpoint=f"{master_addr}:{master_port}",
rdzv_id="colossalai-default-job")
# update rdzv arguments
for key in default_torchrun_rdzv_args.keys():
if key in extra_launch_args:
value = extra_launch_args.pop(key)
default_torchrun_rdzv_args[key] = value
if torch_version.minor < 10:
cmd = [
sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}",
f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
]
else:
cmd = [
"torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
]
cmd += _arg_dict_to_list(default_torchrun_rdzv_args)
return parse_device_filter(active_devices, include_str=inclusion, exclude_str=exclusion) cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args
cmd = ' '.join(cmd)
return cmd
def launch_multi_processes(args): def launch_multi_processes(args: Config) -> None:
""" """
Launch multiple processes on a single node or multiple nodes. Launch multiple processes on a single node or multiple nodes.
The overall logic can be summarized as the pseudo code below: The overall logic can be summarized as the pseudo code below:
if hostfile given: if hostfile given:
hostinfo = parse_hostfile(hostfile) hostinfo = parse_hostfile(hostfile)
hostinfo = include_or_exclude_hosts(hostinfo) hostinfo = include_or_exclude_hosts(hostinfo)
launch_on_multi_nodes(hostinfo) launch_on_multi_nodes(hostinfo)
elif hosts given: elif hosts given:
hostinfo = parse_hosts(hosts) hostinfo = parse_hosts(hosts)
launch_on_multi_nodes(hostinfo) launch_on_multi_nodes(hostinfo)
else: else:
launch_on_current_node() launch_on_current_node()
Args:
args (Config): the arguments taken from command line
""" """
assert isinstance(args, Config) assert isinstance(args, Config)
if args.nproc_per_node is None:
click.echo("--nproc_per_node did not receive any value")
exit()
# cannot accept hosts and hostfile at the same time # cannot accept hosts and hostfile at the same time
if args.host and args.hostfile: if args.host and args.hostfile:
click.echo("Error: hostfile and hosts are mutually exclusive, only one is required") click.echo("Error: hostfile and hosts are mutually exclusive, only one is required")
# check if hostfile is given # check if hostfile is given
if args.hostfile: if args.hostfile:
device_pool = fetch_hostfile(args.hostfile) device_pool = fetch_hostfile(args.hostfile, ssh_port=args.ssh_port)
else: active_device_pool = parse_device_filter(device_pool, args.include, args.exclude)
device_pool = None
# filter and only keep the ones needed
active_devices = None
if device_pool:
active_devices = parse_inclusion_exclusion(device_pool, args.include, args.exclude)
if args.num_nodes > 0: if args.num_nodes > 0:
# only keep the first num_nodes to execute jobs # only keep the first num_nodes to execute jobs
updated_active_devices = collections.OrderedDict() updated_active_device_pool = HostInfoList()
for count, hostname in enumerate(active_devices.keys()): for count, hostinfo in enumerate(active_device_pool):
if args.num_nodes == count: if args.num_nodes == count:
break break
updated_active_devices[hostname] = active_devices[hostname] updated_active_device_pool.append(hostinfo)
active_devices = updated_active_devices active_device_pool = updated_active_device_pool
else:
if args.nproc_per_node > 0: active_device_pool = None
# only keep the first
updated_active_devices = collections.OrderedDict()
for hostname, active_devices in active_devices.items():
if len(active_devices) < args.nproc_per_node:
click.echo(
f"Error: The number of available GPUs on {hostname} is smaller than the argument nproc_per_node"
)
exit()
updated_active_devices[hostname] = active_devices[args.nproc_per_node]
active_devices = updated_active_devices
env = os.environ.copy() env = os.environ.copy()
# use hosts if hostfile is not given # use hosts if hostfile is not given
if args.host and active_devices is None: if args.host and active_device_pool is None:
hostinfo = collections.OrderedDict() active_device_pool = HostInfoList()
host_list = args.host.strip().split(',') host_list = args.host.strip().split(NODE_SEP)
for hostname in host_list: for hostname in host_list:
hostinfo[hostname] = args.nproc_per_node hostinfo = HostInfo(hostname=hostname, port=args.ssh_port)
active_devices = hostinfo active_device_pool.append(hostinfo)
# run on local node if not hosts or hostfile is given if not active_device_pool:
if not active_devices: # run on local node if not hosts or hostfile is given
if args.nproc_per_node == -1 or args.nproc_per_node > torch.cuda.device_count(): # add local node to host info list
nproc_per_node = torch.cuda.device_count() active_device_pool = HostInfoList()
else: localhost_info = HostInfo(hostname='127.0.0.1', port=args.ssh_port)
nproc_per_node = args.nproc_per_node active_device_pool.append(localhost_info)
if torch.__version__ <= "1.9": # launch distributed processes
cmd = [ runner = MultiNodeRunner()
sys.executable, "-u", "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}", curr_path = os.path.abspath('.')
f"--master_addr={args.master_addr}", f"--master_port={args.master_port}"
] + [args.user_script] + args.user_args # collect current path env
else: env = dict()
cmd = [ for k, v in os.environ.items():
"torchrun", f"--nproc_per_node={nproc_per_node}", f"--master_addr={args.master_addr}", # do not support multi-line env var
f"--master_port={args.master_port}" if v and '\n' not in v:
] + [args.user_script] + args.user_args env[k] = v
else:
runner = PDSHRunner(args) # establish remote connection
runner.connect(host_info_list=active_device_pool, workdir=curr_path, env=env)
curr_path = os.path.abspath('.')
if 'PYTHONPATH' in env: # execute distributed launching command
env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH'] for node_id, hostinfo in enumerate(active_device_pool):
else: cmd = get_launch_command(master_addr=args.master_addr,
env['PYTHONPATH'] = curr_path master_port=args.master_port,
nproc_per_node=args.nproc_per_node,
cmd = runner.get_cmd(env, active_devices, args) user_script=args.user_script,
user_args=args.user_args,
result = subprocess.Popen(cmd, env=env) node_rank=node_id,
result.wait() num_nodes=len(active_device_pool),
if result.returncode > 0: extra_launch_args=args.extra_launch_args)
sys.exit(result.returncode) runner.send(hostinfo=hostinfo, cmd=cmd)
runner.recv_from_all()
runner.stop_all()
runner.recv_from_all()
...@@ -6,3 +6,4 @@ packaging ...@@ -6,3 +6,4 @@ packaging
pre-commit pre-commit
rich rich
click click
fabric
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment