Unverified Commit 8bc91414 authored by Jinjing Zhou, committed by GitHub

[Distributed] Fix distributed training hang with multiple samplers (#3169)

Rewrite the multiprocessing worker pool
parent 2ec0493d
"""Initialize the distributed services"""
# pylint: disable=line-too-long
import multiprocessing as mp
import traceback
@@ -6,6 +7,8 @@ import atexit
import time
import os
import sys
import queue
from enum import Enum
from . import rpc
from .constants import MAX_QUEUE_SIZE
@@ -18,15 +21,18 @@ SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = 0
INITIALIZED = False
def set_initialized(value=True):
"""Set the initialized state of rpc"""
global INITIALIZED
INITIALIZED = value
def get_sampler_pool():
"""Return the sampler pool and num_workers"""
return SAMPLER_POOL, NUM_SAMPLER_WORKERS
def _init_rpc(ip_config, num_servers, max_queue_size, net_type, role, num_threads):
''' This init function is called in the worker processes.
'''
@@ -41,6 +47,131 @@ def _init_rpc(ip_config, num_servers, max_queue_size, net_type, role, num_thread
traceback.print_exc()
raise e
class MpCommand(Enum):
"""Enum class for multiprocessing command"""
INIT_RPC = 0 # Not used in the task queue
SET_COLLATE_FN = 1
CALL_BARRIER = 2
DELETE_COLLATE_FN = 3
CALL_COLLATE_FN = 4
CALL_FN_ALL_WORKERS = 5
FINALIZE_POOL = 6
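# How tasks travel through the pool: each entry on a worker's task queue is a
# (MpCommand, args) tuple. A hypothetical exchange (names are illustrative,
# not part of this change):
#
#   task_queue.put((MpCommand.SET_COLLATE_FN, ("loader-0", collate)))
#   task_queue.put((MpCommand.CALL_COLLATE_FN, ("loader-0", batch_indices)))
#   # The worker then replies on the shared data queue:
#   #   data_queue.get()  ->  ("loader-0", collate(batch_indices))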
def init_process(rpc_config, mp_contexts):
"""Work loop in the worker"""
try:
_init_rpc(*rpc_config)
keep_polling = True
data_queue, task_queue, barrier = mp_contexts
collate_fn_dict = {}
while keep_polling:
try:
# Follow https://github.com/pytorch/pytorch/blob/d57ce8cf8989c0b737e636d8d7abe16c1f08f70b/torch/utils/data/_utils/worker.py#L260
command, args = task_queue.get(timeout=5)
except queue.Empty:
continue
if command == MpCommand.SET_COLLATE_FN:
dataloader_name, func = args
collate_fn_dict[dataloader_name] = func
elif command == MpCommand.CALL_BARRIER:
barrier.wait()
elif command == MpCommand.DELETE_COLLATE_FN:
dataloader_name, = args
del collate_fn_dict[dataloader_name]
elif command == MpCommand.CALL_COLLATE_FN:
dataloader_name, collate_args = args
data_queue.put(
(dataloader_name, collate_fn_dict[dataloader_name](collate_args)))
elif command == MpCommand.CALL_FN_ALL_WORKERS:
func, func_args = args
func(func_args)
elif command == MpCommand.FINALIZE_POOL:
_exit()
keep_polling = False
else:
raise Exception("Unknown command")
except Exception as e:
traceback.print_exc()
raise e
class CustomPool:
"""Customized worker pool"""
def __init__(self, num_workers, rpc_config):
"""
Customized worker pool init function
"""
ctx = mp.get_context("spawn")
self.num_workers = num_workers
self.queue_size = num_workers * 4
self.result_queue = ctx.Queue(self.queue_size)
self.task_queues = []
self.process_list = []
self.current_proc_id = 0
self.cache_result_dict = {}
self.barrier = ctx.Barrier(num_workers)
for _ in range(num_workers):
task_queue = ctx.Queue(self.queue_size)
self.task_queues.append(task_queue)
proc = ctx.Process(target=init_process, args=(
rpc_config, (self.result_queue, task_queue, self.barrier)))
proc.daemon = True
proc.start()
self.process_list.append(proc)
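    # Design note: every worker owns a private task queue, so commands can be
    # broadcast to all workers (set_collate_fn, call_barrier) or dispatched
    # round-robin (submit_task), while a single shared result queue funnels
    # collated batches back to the caller.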
def set_collate_fn(self, func, dataloader_name):
"""Set collate function in subprocess"""
for i in range(self.num_workers):
self.task_queues[i].put(
(MpCommand.SET_COLLATE_FN, (dataloader_name, func)))
def submit_task(self, dataloader_name, args):
"""Submit task to workers"""
# Round robin
self.task_queues[self.current_proc_id].put(
(MpCommand.CALL_COLLATE_FN, (dataloader_name, args)))
self.current_proc_id = (self.current_proc_id + 1) % self.num_workers
def submit_task_to_all_workers(self, func, args):
"""Submit task to all workers"""
for i in range(self.num_workers):
self.task_queues[i].put(
(MpCommand.CALL_FN_ALL_WORKERS, (func, args)))
def get_result(self, dataloader_name, timeout=1800):
"""Get result from result queue"""
result_dataloader_name, result = self.result_queue.get(timeout=timeout)
assert result_dataloader_name == dataloader_name
return result
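    # Results arrive on the shared queue in completion order; the assertion
    # above guards against accidentally draining a batch that belongs to a
    # different dataloader.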
def delete_collate_fn(self, dataloader_name):
"""Delete collate function"""
for i in range(self.num_workers):
self.task_queues[i].put(
(MpCommand.DELETE_COLLATE_FN, (dataloader_name, )))
def call_barrier(self):
"""Call barrier at all workers"""
for i in range(self.num_workers):
self.task_queues[i].put(
(MpCommand.CALL_BARRIER, tuple()))
def close(self):
"""Close worker pool"""
for i in range(self.num_workers):
self.task_queues[i].put((MpCommand.FINALIZE_POOL, tuple()), block=False)
        time.sleep(0.5)  # Workaround needed on older Python versions
def join(self):
"""Join the close process of worker pool"""
for i in range(self.num_workers):
self.process_list[i].join()
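# A minimal usage sketch for CustomPool. The rpc_config tuple mirrors the
# arguments of _init_rpc; all names and values below are placeholders:
#
#   pool = CustomPool(2, ("ip_config.txt", 1, MAX_QUEUE_SIZE,
#                         "socket", "sampler", 1))
#   pool.set_collate_fn(my_collate, "loader-0")
#   pool.submit_task("loader-0", batch_indices)
#   batch = pool.get_result("loader-0")
#   pool.delete_collate_fn("loader-0")
#   pool.close()
#   pool.join()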
def initialize(ip_config, num_servers=1, num_workers=0,
max_queue_size=MAX_QUEUE_SIZE, net_type='socket',
num_worker_threads=1):
@@ -84,15 +215,15 @@ def initialize(ip_config, num_servers=1, num_workers=0,
if os.environ.get('DGL_ROLE', 'client') == 'server':
from .dist_graph import DistGraphServer
        assert os.environ.get('DGL_SERVER_ID') is not None, \
            'Please define DGL_SERVER_ID to run DistGraph server'
        assert os.environ.get('DGL_IP_CONFIG') is not None, \
            'Please define DGL_IP_CONFIG to run DistGraph server'
        assert os.environ.get('DGL_NUM_SERVER') is not None, \
            'Please define DGL_NUM_SERVER to run DistGraph server'
        assert os.environ.get('DGL_NUM_CLIENT') is not None, \
            'Please define DGL_NUM_CLIENT to run DistGraph server'
        assert os.environ.get('DGL_CONF_PATH') is not None, \
            'Please define DGL_CONF_PATH to run DistGraph server'
formats = os.environ.get('DGL_GRAPH_FORMAT', 'csc').split(',')
formats = [f.strip() for f in formats]
serv = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
@@ -114,46 +245,47 @@
num_servers = 1
rpc.reset()
ctx = mp.get_context("spawn")
global SAMPLER_POOL
global NUM_SAMPLER_WORKERS
    is_standalone = os.environ.get(
        'DGL_DIST_MODE', 'standalone') == 'standalone'
if num_workers > 0 and not is_standalone:
        SAMPLER_POOL = CustomPool(num_workers, (ip_config, num_servers, max_queue_size,
                                                net_type, 'sampler', num_worker_threads))
else:
SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = num_workers
if not is_standalone:
assert num_servers is not None and num_servers > 0, \
            'The number of servers per machine must be specified with a positive number.'
connect_to_server(ip_config, num_servers, max_queue_size, net_type)
init_role('default')
init_kvstore(ip_config, num_servers, 'default')
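# Typical trainer-side call (illustrative sketch; the ip_config path and
# worker count are placeholders):
#
#   import dgl
#   dgl.distributed.initialize("ip_config.txt", num_workers=4)
#   # ... build a DistGraph and DistDataLoader, then train ...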
def finalize_client():
"""Release resources of this client."""
    if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone':
rpc.finalize_sender()
rpc.finalize_receiver()
global INITIALIZED
INITIALIZED = False
def _exit():
exit_client()
time.sleep(1)
def finalize_worker():
"""Finalize workers
Python's multiprocessing pool will not call atexit function when close
"""
global SAMPLER_POOL
if SAMPLER_POOL is not None:
        SAMPLER_POOL.close()
def join_finalize_worker():
"""join the worker close process"""
global SAMPLER_POOL
@@ -161,11 +293,13 @@ def join_finalize_worker():
SAMPLER_POOL.join()
SAMPLER_POOL = None
def is_initialized():
"""Is RPC initialized?
"""
return INITIALIZED
def exit_client():
"""Trainer exits
@@ -177,8 +311,8 @@ def exit_client():
needs to call `exit_client` before calling `initialize` again.
"""
# Only client with rank_0 will send shutdown request to servers.
    finalize_worker()  # Finalizing workers must happen before the barrier, and is non-blocking.
    if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone':
rpc.client_barrier()
shutdown_servers()
finalize_client()
......
# pylint: disable=global-variable-undefined, invalid-name
"""Multiprocess dataloader for distributed training"""
import multiprocessing as mp
from queue import Queue
import traceback
from .dist_context import get_sampler_pool
from .. import backend as F
__all__ = ["DistDataLoader"]
def call_collate_fn(name, next_data):
"""Call collate function"""
try:
result = DGL_GLOBAL_COLLATE_FNS[name](next_data)
DGL_GLOBAL_MP_QUEUES[name].put(result)
except Exception as e:
traceback.print_exc()
print(e)
raise e
return 1
DGL_GLOBAL_COLLATE_FNS = {}
DGL_GLOBAL_MP_QUEUES = {}
def init_fn(barrier, name, collate_fn, queue):
"""Initialize setting collate function and mp.Queue in the subprocess"""
global DGL_GLOBAL_COLLATE_FNS
global DGL_GLOBAL_MP_QUEUES
DGL_GLOBAL_MP_QUEUES[name] = queue
DGL_GLOBAL_COLLATE_FNS[name] = collate_fn
barrier.wait()
return 1
def cleanup_fn(barrier, name):
"""Clean up the data of a dataloader in the worker process"""
global DGL_GLOBAL_COLLATE_FNS
global DGL_GLOBAL_MP_QUEUES
del DGL_GLOBAL_MP_QUEUES[name]
del DGL_GLOBAL_COLLATE_FNS[name]
    # The barrier ensures this function has executed in all worker processes;
    # a better solution is probably needed in the future.
barrier.wait()
return 1
def enable_mp_debug():
"""Print multiprocessing debug information. This is only
for debug usage"""
import logging
logger = mp.log_to_stderr()
logger.setLevel(logging.DEBUG)
DATALOADER_ID = 0
class DistDataLoader:
"""DGL customized multiprocessing dataloader.
@@ -112,16 +66,12 @@
self.pool, self.num_workers = get_sampler_pool()
if queue_size is None:
queue_size = self.num_workers * 4 if self.num_workers > 0 else 4
        self.queue_size = queue_size  # prefetch size
self.batch_size = batch_size
self.num_pending = 0
self.collate_fn = collate_fn
self.current_pos = 0
        self.queue = []  # Only used when pool is None
self.drop_last = drop_last
self.recv_idxs = 0
self.shuffle = shuffle
@@ -140,34 +90,24 @@
DATALOADER_ID += 1
if self.pool is not None:
            self.pool.set_collate_fn(self.collate_fn, self.name)
def __del__(self):
# When the process exits, the process pool may have been closed. We should try
# and get the process pool again and see if we need to clean up the process pool.
self.pool, self.num_workers = get_sampler_pool()
if self.pool is not None:
            self.pool.delete_collate_fn(self.name)
def __next__(self):
        if self.pool is None:
            num_reqs = 1
        else:
            num_reqs = self.queue_size - self.num_pending
for _ in range(num_reqs):
self._request_next_batch()
if self.recv_idxs < self.expected_idxs:
            result = self._get_data_from_result_queue()
self.recv_idxs += 1
self.num_pending -= 1
return result
@@ -175,6 +115,13 @@
assert self.num_pending == 0
raise StopIteration
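    # Note: with a worker pool, __next__ keeps up to queue_size collate
    # requests in flight, overlapping sampling with training; without a pool
    # it issues exactly one synchronous request per batch.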
def _get_data_from_result_queue(self, timeout=1800):
if self.pool is None:
ret = self.queue.pop(0)
else:
ret = self.pool.get_result(self.name, timeout=timeout)
return ret
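    # Without a pool, batches were collated synchronously in
    # _request_next_batch and are popped FIFO from the local list; with a
    # pool, get_result blocks on the shared result queue until a worker
    # finishes a batch for this dataloader.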
def __iter__(self):
if self.shuffle:
self.data_idx = F.rand_shuffle(self.data_idx)
@@ -188,10 +135,10 @@
if next_data is None:
return
elif self.pool is not None:
            self.pool.submit_task(self.name, next_data)
else:
result = self.collate_fn(next_data)
            self.queue.append(result)
self.num_pending += 1
def _next_data(self):
......
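# Illustrative DistDataLoader usage (hypothetical names: train_nids is a
# tensor of seed nodes, sample_blocks a user-defined collate function):
#
#   dataloader = DistDataLoader(dataset=train_nids, batch_size=1024,
#                               collate_fn=sample_blocks, shuffle=True)
#   for batch in dataloader:
#       train_step(batch)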