Unverified Commit fcd8ed9a authored by Rhett Ying, committed by GitHub

[Feature] Launch Long Live Servers and Multiple Client Groups (#3688)

* enable launching multiple client groups sequentially

* enable launching client groups simultaneously

* refine docstring

* revert unnecessary change

* [DOC] add doc for long live server

* refine

* refine doc

* refine doc
parent 738e8318
@@ -50,7 +50,7 @@ Below shows an example of launching a distributed training job in a cluster.

   --num_servers 1 \
   --part_config data/ogb-product.json \
   --ip_config ip_config.txt \
-  "python3 code/train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 5 --batch-size 1000 --lr 0.1 --num_workers 4"
+  "python3 code/train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 5 --batch-size 1000 --lr 0.1"

The configuration file *ip_config.txt* contains the IP addresses of the machines in a cluster.
A typical example of *ip_config.txt* is as follows:
@@ -75,3 +75,48 @@ The launch script creates a specified number of training jobs (``--num_trainers``)

In addition, a user needs to specify the number of sampler processes for each trainer
(``--num_samplers``). The number of sampler processes has to match the number of worker processes
specified in :func:`~dgl.distributed.initialize`.
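For instance, a trainer launched with ``--num_samplers 4`` would initialize DGL as in
the minimal sketch below (illustrative only; the actual ``code/train_dist.py`` is not
part of this commit):

.. code:: python

   import dgl
   import torch as th

   # num_workers must match the --num_samplers value passed to tools/launch.py
   dgl.distributed.initialize('ip_config.txt', num_workers=4)
   th.distributed.init_process_group(backend='gloo')
   g = dgl.distributed.DistGraph('ogb-product')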
It is common that users may want to try different models or training configurations
against the same graph data. To avoid repeatedly loading the same graph data, DGL
allows users to launch a persistent graph server that is shared across multiple training
jobs. A persistent graph server stays alive even after all training workers have
finished and exited. Below is an example of launching a persistent graph server.
We first launch the graph server together with the first group of training workers.

.. code:: none

   python3 tools/launch.py \
   --workspace ~graphsage/ \
   --num_trainers 2 \
   --num_samplers 4 \
   --num_servers 1 \
   --part_config data/ogb-product.json \
   --ip_config ip_config.txt \
   --keep_alive \
   --server_name long_live \
   "python3 code/train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 5 --batch-size 1000 --lr 0.1"

Pay attention to the ``--keep_alive`` option, which indicates that the server should
stay alive after the workers have finished. ``--server_name`` is the given name of
the server, which will be referred to when launching new training jobs.

Then launch another group of distributed training jobs and connect to the existing persistent server.
.. code:: none

   python3 tools/launch.py \
   --workspace ~graphsage/ \
   --num_trainers 2 \
   --num_samplers 4 \
   --num_servers 1 \
   --part_config data/ogb-product.json \
   --ip_config ip_config.txt \
   --server_name long_live \
   "python3 code/train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 5 --batch-size 1000 --lr 0.1"

.. note::

   All the arguments for ``launch.py`` should be kept the same as in the previous launch.
   In addition, the following arguments of the training script must stay the same:
   ``--graph-name`` and ``--ip_config``. Other arguments such as ``--num-epochs``,
   ``--batch-size`` and so on are free to change.
@@ -228,7 +228,7 @@ def initialize(ip_config, num_servers=1, num_workers=0,

        formats = os.environ.get('DGL_GRAPH_FORMAT', 'csc').split(',')
        formats = [f.strip() for f in formats]
        rpc.reset()
-       keep_alive = os.environ.get('DGL_KEEP_ALIVE') is not None
+       keep_alive = bool(int(os.environ.get('DGL_KEEP_ALIVE', 0)))
        serv = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
                               os.environ.get('DGL_IP_CONFIG'),
                               int(os.environ.get('DGL_NUM_SERVER')),
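The server now derives ``keep_alive`` from an explicit 0/1 value instead of the mere
presence of the variable, because the launcher always exports ``DGL_KEEP_ALIVE`` for
the server processes it starts (see ``DGL_KEEP_ALIVE=int(keep_alive)`` in
``tools/launch.py`` below). A minimal sketch of the difference:

.. code:: python

   import os

   os.environ['DGL_KEEP_ALIVE'] = '0'                    # launcher ran without --keep_alive
   old = os.environ.get('DGL_KEEP_ALIVE') is not None    # True  -- would wrongly keep the server alive
   new = bool(int(os.environ.get('DGL_KEEP_ALIVE', 0)))  # False -- matches the launcher's intent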
@@ -322,6 +322,8 @@ def exit_client():

    needs to call `exit_client` before calling `initialize` again.
    """
    # Only client with rank_0 will send shutdown request to servers.
    print("Client[{}] in group[{}] is exiting...".format(
        rpc.get_rank(), rpc.get_group_id()))
    finalize_worker() # finalize workers should be earlier than barrier, and non-blocking
    # collect data such as DistTensor before exit
    gc.collect()
......
@@ -15,7 +15,7 @@ from .. import backend as F

__all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \
           'create_sender', 'create_receiver', 'finalize_sender', 'finalize_receiver', \
-          'receiver_wait', 'connect_receiver', 'read_ip_config', \
+          'receiver_wait', 'connect_receiver', 'read_ip_config', 'get_group_id', \
           'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
           'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
           'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', \
......
@@ -14,8 +14,6 @@ from functools import partial

from threading import Thread
from typing import Optional

-DEFAULT_PORT = 30050

def cleanup_proc(get_all_remote_pids, conn):
    '''This process tries to clean up the remote training tasks.
    '''
@@ -271,6 +269,7 @@ def construct_dgl_server_env_vars(

    ip_config: str,
    num_servers: int,
    graph_format: str,
    keep_alive: bool,
    pythonpath: Optional[str] = "",
) -> str:
    """Constructs the DGL server-specific env vars string that are required for DGL code to behave in the correct
@@ -287,6 +286,8 @@ def construct_dgl_server_env_vars(

        Relative path to workspace.
    num_servers:
    graph_format:
    keep_alive:
        Whether to keep server alive when clients exit
    pythonpath: Optional. If given, this will pass this as PYTHONPATH.

    Returns:
@@ -302,6 +303,7 @@ def construct_dgl_server_env_vars(

        "DGL_IP_CONFIG={DGL_IP_CONFIG} "
        "DGL_NUM_SERVER={DGL_NUM_SERVER} "
        "DGL_GRAPH_FORMAT={DGL_GRAPH_FORMAT} "
        "DGL_KEEP_ALIVE={DGL_KEEP_ALIVE} "
        "{suffix_optional_envvars}"
    )
    suffix_optional_envvars = ""
@@ -316,6 +318,7 @@ def construct_dgl_server_env_vars(

        DGL_IP_CONFIG=ip_config,
        DGL_NUM_SERVER=num_servers,
        DGL_GRAPH_FORMAT=graph_format,
        DGL_KEEP_ALIVE=int(keep_alive),
        suffix_optional_envvars=suffix_optional_envvars,
    )
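A hypothetical call (argument values are made up for illustration) showing the effect of
the new ``keep_alive`` parameter; the returned environment string now always carries
``DGL_KEEP_ALIVE``, which ``dgl.distributed.initialize`` converts back to a bool on the
server side:

.. code:: python

   env = construct_dgl_server_env_vars(
       num_samplers=4,
       num_server_threads=1,
       tot_num_clients=40,   # num_trainers * (1 + num_samplers) * num_machines, e.g. 2 * 5 * 4
       part_config='data/ogb-product.json',
       ip_config='ip_config.txt',
       num_servers=1,
       graph_format='csc',
       keep_alive=True,
   )
   assert 'DGL_KEEP_ALIVE=1' in env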
@@ -328,6 +331,7 @@ def construct_dgl_client_env_vars(

    num_servers: int,
    graph_format: str,
    num_omp_threads: int,
    group_id: int,
    pythonpath: Optional[str] = "",
) -> str:
    """Constructs the DGL client-specific env vars string that are required for DGL code to behave in the correct
@@ -344,6 +348,8 @@ def construct_dgl_client_env_vars(

    num_servers:
    graph_format:
    num_omp_threads:
    group_id:
        Used in client processes to indicate which group it belongs to.
    pythonpath: Optional. If given, this will pass this as PYTHONPATH.

    Returns:
@@ -360,6 +366,7 @@ def construct_dgl_client_env_vars(

        "DGL_NUM_SERVER={DGL_NUM_SERVER} "
        "DGL_GRAPH_FORMAT={DGL_GRAPH_FORMAT} "
        "OMP_NUM_THREADS={OMP_NUM_THREADS} "
        "DGL_GROUP_ID={DGL_GROUP_ID} "
        "{suffix_optional_envvars}"
    )
    # append optional additional env-vars
@@ -376,6 +383,7 @@ def construct_dgl_client_env_vars(

        DGL_NUM_SERVER=num_servers,
        DGL_GRAPH_FORMAT=graph_format,
        OMP_NUM_THREADS=num_omp_threads,
        DGL_GROUP_ID=group_id,
        suffix_optional_envvars=suffix_optional_envvars,
    )
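The launcher exports ``DGL_GROUP_ID`` to every client process, and the assigned group id
becomes queryable through the ``rpc`` helpers this commit exposes (``get_group_id`` is
added to ``rpc.__all__`` above). A hypothetical check inside a launched training script
(not part of this diff):

.. code:: python

   import dgl
   from dgl.distributed import rpc

   dgl.distributed.initialize('ip_config.txt')
   print('client rank {} in group {}'.format(rpc.get_rank(), rpc.get_group_id()))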
@@ -424,6 +432,72 @@ def wrap_cmd_with_extra_envvars(cmd: str, env_vars: list) -> str:

    env_vars = " ".join(env_vars)
    return wrap_cmd_with_local_envvars(cmd, env_vars)
g_monitor_file = None
g_group_id = 0

def has_alive_servers(args):
    """Check whether there exist alive servers.

    For each group of long live servers, a monitor file named
    'dgl_dist_monitor_{args.server_name}' is created under the '/tmp/' directory.
    We check the existence of this monitor file to determine whether to
    launch new servers or utilize the existing alive ones. If alive servers
    exist, we obtain the available group ID from the monitor file, which is
    used for the current client group.

    Returns
    -------
    bool
        indicates whether there exist alive servers.
    """
    if args.server_name is None:
        return False
    global g_monitor_file
    global g_group_id
    monitor_file = '/tmp/dgl_dist_monitor_' + args.server_name
    from filelock import FileLock
    lock = FileLock(monitor_file + '.lock')
    with lock:
        next_group_id = None
        ret = os.path.exists(monitor_file)
        if ret:
            print("Monitor file for alive servers already exist: {}.".format(monitor_file))
            lines = [line.rstrip('\n') for line in open(monitor_file)]
            g_group_id = int(lines[0])
            next_group_id = g_group_id + 1
        if not ret and args.keep_alive:
            next_group_id = 1
            print("Monitor file for alive servers is created: {}.".format(monitor_file))
            g_monitor_file = monitor_file
        if next_group_id is not None:
            with open(monitor_file, 'w') as f:
                f.write(str(next_group_id))
    return ret
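To make the bookkeeping concrete, here is how two consecutive launches behave
(``SimpleNamespace`` stands in for the parsed argparse namespace; the ``filelock``
package imported above is required):

.. code:: python

   from types import SimpleNamespace

   # First launch with --keep_alive: no monitor file yet, so servers are started
   # and /tmp/dgl_dist_monitor_long_live is created with group id "1".
   first = SimpleNamespace(server_name='long_live', keep_alive=True)
   assert has_alive_servers(first) is False

   # A later launch with the same --server_name: the monitor file exists, so the
   # running servers are reused, g_group_id becomes 1 and the file is bumped to "2".
   later = SimpleNamespace(server_name='long_live', keep_alive=False)
   assert has_alive_servers(later) is True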
def clean_alive_servers():
    """Remove keep alive related files"""
    global g_monitor_file
    try:
        if g_monitor_file is not None:
            os.remove(g_monitor_file)
            os.remove(g_monitor_file + '.lock')
            print("Monitor file for alive servers is removed: {}.".format(g_monitor_file))
    except:
        print("Failed to delete monitor file for alive servers: {}.".format(g_monitor_file))
def get_available_port(ip):
    """Get available port with specified ip."""
    import socket
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    for port in range(1234, 65535):
        try:
            sock.connect((ip, port))
        except:
            return port
    raise RuntimeError("Failed to get available port for ip~{}".format(ip))
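``get_available_port`` probes ports by attempting to connect to the target ``ip`` from
the launcher machine and treats a failed connection as "available". A more conventional
local-only alternative (illustrative sketch, not what ``launch.py`` does) is to bind to
port 0 and let the OS pick a free port:

.. code:: python

   import socket

   def pick_free_local_port():
       # Only checks availability on the machine where it runs, unlike
       # get_available_port above, which probes the target ip remotely.
       with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
           s.bind(('', 0))
           return s.getsockname()[1]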
def submit_jobs(args, udf_command):
    """Submit distributed jobs (server and client processes) via ssh"""
    hosts = []
@@ -441,7 +515,7 @@ def submit_jobs(args, udf_command):

                hosts.append((ip, port))
            elif len(result) == 1:
                ip = result[0]
-               port = DEFAULT_PORT
+               port = get_available_port(ip)
                hosts.append((ip, port))
            else:
                raise RuntimeError("Format error of ip_config.")
@@ -457,23 +531,27 @@ def submit_jobs(args, udf_command):

    tot_num_clients = args.num_trainers * (1 + args.num_samplers) * len(hosts)
    # launch server tasks
    if not has_alive_servers(args):
        server_env_vars = construct_dgl_server_env_vars(
            num_samplers=args.num_samplers,
            num_server_threads=args.num_server_threads,
            tot_num_clients=tot_num_clients,
            part_config=args.part_config,
            ip_config=args.ip_config,
            num_servers=args.num_servers,
            graph_format=args.graph_format,
            keep_alive=args.keep_alive,
            pythonpath=os.environ.get("PYTHONPATH", ""),
        )
        for i in range(len(hosts) * server_count_per_machine):
            ip, _ = hosts[int(i / server_count_per_machine)]
            server_env_vars_cur = f"{server_env_vars} DGL_SERVER_ID={i}"
            cmd = wrap_cmd_with_local_envvars(udf_command, server_env_vars_cur)
            cmd = wrap_cmd_with_extra_envvars(cmd, args.extra_envs) if len(args.extra_envs) > 0 else cmd
            cmd = 'cd ' + str(args.workspace) + '; ' + cmd
            thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))
    else:
        print(f"Use running server {args.server_name}.")
    # launch client tasks
    client_env_vars = construct_dgl_client_env_vars(
@@ -484,6 +562,7 @@ def submit_jobs(args, udf_command):

        num_servers=args.num_servers,
        graph_format=args.graph_format,
        num_omp_threads=os.environ.get("OMP_NUM_THREADS", str(args.num_omp_threads)),
        group_id=g_group_id,
        pythonpath=os.environ.get("PYTHONPATH", ""),
    )
@@ -496,7 +575,7 @@ def submit_jobs(args, udf_command):

            num_nodes=len(hosts),
            node_rank=node_id,
            master_addr=hosts[0][0],
-           master_port=1234,
+           master_port=get_available_port(hosts[0][0]),
        )
        cmd = wrap_cmd_with_local_envvars(torch_dist_udf_command, client_env_vars)
        cmd = wrap_cmd_with_extra_envvars(cmd, args.extra_envs) if len(args.extra_envs) > 0 else cmd
@@ -513,6 +592,7 @@ def submit_jobs(args, udf_command):

        logging.info('Stop launcher')
        # We need to tell the cleanup process to kill remote training jobs.
        conn2.send('cleanup')
        clean_alive_servers()
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
@@ -560,7 +640,13 @@ def main():

                        help='Extra environment parameters need to be set. For example, \
                        you can set the LD_LIBRARY_PATH and NCCL_DEBUG by adding: \
                        --extra_envs LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH NCCL_DEBUG=INFO ')
    parser.add_argument('--keep_alive', action='store_true', help='Servers keep alive when clients exit')
    parser.add_argument('--server_name', type=str,
                        help='Used to check whether there exist alive servers')
    args, udf_command = parser.parse_known_args()
    if args.keep_alive:
        assert args.server_name is not None, "Server name is required if '--keep_alive' is enabled."
        print("Servers will keep alive even clients exit...")
    assert len(udf_command) == 1, 'Please provide user command line.'
    assert args.num_trainers is not None and args.num_trainers > 0, \
        '--num_trainers must be a positive number.'
......