"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "1db71cfefec88e05299c6d29464f7e1737adb4b9"
Unverified commit 8bc91414, authored by Jinjing Zhou, committed by GitHub

[Distributed] Fix distributed training hang with multiple samplers (#3169)

Rewrite the multiprocessing worker pool
parent 2ec0493d
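The change below drops the `multiprocessing.Pool`/`apply_async` sampler workers and replaces them with a hand-rolled pool (`CustomPool`): each worker process owns a dedicated task queue, all workers share one result queue, and requests travel as `(MpCommand, args)` tuples. The following is a minimal, self-contained sketch of that message-passing pattern, illustrative only and not taken from the DGL code; names such as `_worker` are hypothetical.

import multiprocessing as mp
import queue

def _worker(task_queue, result_queue):
    # Persistent worker: poll the private task queue, push answers to the
    # shared result queue, and stop on a "finalize" message.
    while True:
        try:
            command, args = task_queue.get(timeout=5)
        except queue.Empty:
            continue
        if command == "finalize":
            break
        result_queue.put((command, args * 2))  # stand-in for a collate call

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    result_queue = ctx.Queue()
    task_queues = [ctx.Queue() for _ in range(2)]      # one task queue per worker
    workers = [ctx.Process(target=_worker, args=(q, result_queue), daemon=True)
               for q in task_queues]
    for w in workers:
        w.start()
    for i, q in enumerate(task_queues):                 # round-robin dispatch
        q.put(("collate", i))
    print([result_queue.get(timeout=10) for _ in workers])
    for q in task_queues:
        q.put(("finalize", None))
    for w in workers:
        w.join()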
"""Initialize the distributed services""" """Initialize the distributed services"""
# pylint: disable=line-too-long
import multiprocessing as mp import multiprocessing as mp
import traceback import traceback
...@@ -6,6 +7,8 @@ import atexit ...@@ -6,6 +7,8 @@ import atexit
import time import time
import os import os
import sys import sys
import queue
from enum import Enum
from . import rpc from . import rpc
from .constants import MAX_QUEUE_SIZE from .constants import MAX_QUEUE_SIZE
...@@ -18,15 +21,18 @@ SAMPLER_POOL = None ...@@ -18,15 +21,18 @@ SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = 0 NUM_SAMPLER_WORKERS = 0
INITIALIZED = False INITIALIZED = False
def set_initialized(value=True): def set_initialized(value=True):
"""Set the initialized state of rpc""" """Set the initialized state of rpc"""
global INITIALIZED global INITIALIZED
INITIALIZED = value INITIALIZED = value
def get_sampler_pool(): def get_sampler_pool():
"""Return the sampler pool and num_workers""" """Return the sampler pool and num_workers"""
return SAMPLER_POOL, NUM_SAMPLER_WORKERS return SAMPLER_POOL, NUM_SAMPLER_WORKERS
def _init_rpc(ip_config, num_servers, max_queue_size, net_type, role, num_threads): def _init_rpc(ip_config, num_servers, max_queue_size, net_type, role, num_threads):
''' This init function is called in the worker processes. ''' This init function is called in the worker processes.
''' '''
...@@ -41,6 +47,131 @@ def _init_rpc(ip_config, num_servers, max_queue_size, net_type, role, num_thread ...@@ -41,6 +47,131 @@ def _init_rpc(ip_config, num_servers, max_queue_size, net_type, role, num_thread
traceback.print_exc() traceback.print_exc()
raise e raise e
class MpCommand(Enum):
    """Enum class for multiprocessing command"""
    INIT_RPC = 0  # Not used in the task queue
    SET_COLLATE_FN = 1
    CALL_BARRIER = 2
    DELETE_COLLATE_FN = 3
    CALL_COLLATE_FN = 4
    CALL_FN_ALL_WORKERS = 5
    FINALIZE_POOL = 6
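# Illustration (not part of the commit): a request is enqueued on a worker's
# task queue as a (command, args) tuple, e.g.
# (MpCommand.CALL_COLLATE_FN, ("dataloader-0", batch_seed_nodes)),
# where the dataloader name and seed batch are hypothetical placeholders.
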
def init_process(rpc_config, mp_contexts):
    """Work loop in the worker"""
    try:
        _init_rpc(*rpc_config)
        keep_polling = True
        data_queue, task_queue, barrier = mp_contexts
        collate_fn_dict = {}
        while keep_polling:
            try:
                # Follow https://github.com/pytorch/pytorch/blob/d57ce8cf8989c0b737e636d8d7abe16c1f08f70b/torch/utils/data/_utils/worker.py#L260
                command, args = task_queue.get(timeout=5)
            except queue.Empty:
                continue
            if command == MpCommand.SET_COLLATE_FN:
                dataloader_name, func = args
                collate_fn_dict[dataloader_name] = func
            elif command == MpCommand.CALL_BARRIER:
                barrier.wait()
            elif command == MpCommand.DELETE_COLLATE_FN:
                dataloader_name, = args
                del collate_fn_dict[dataloader_name]
            elif command == MpCommand.CALL_COLLATE_FN:
                dataloader_name, collate_args = args
                data_queue.put(
                    (dataloader_name, collate_fn_dict[dataloader_name](collate_args)))
            elif command == MpCommand.CALL_FN_ALL_WORKERS:
                func, func_args = args
                func(func_args)
            elif command == MpCommand.FINALIZE_POOL:
                _exit()
                keep_polling = False
            else:
                raise Exception("Unknown command")
    except Exception as e:
        traceback.print_exc()
        raise e

class CustomPool:
    """Customized worker pool"""

    def __init__(self, num_workers, rpc_config):
        """
        Customized worker pool init function
        """
        ctx = mp.get_context("spawn")
        self.num_workers = num_workers
        self.queue_size = num_workers * 4
        self.result_queue = ctx.Queue(self.queue_size)
        self.task_queues = []
        self.process_list = []
        self.current_proc_id = 0
        self.cache_result_dict = {}
        self.barrier = ctx.Barrier(num_workers)
        for _ in range(num_workers):
            task_queue = ctx.Queue(self.queue_size)
            self.task_queues.append(task_queue)
            proc = ctx.Process(target=init_process, args=(
                rpc_config, (self.result_queue, task_queue, self.barrier)))
            proc.daemon = True
            proc.start()
            self.process_list.append(proc)

    def set_collate_fn(self, func, dataloader_name):
        """Set collate function in subprocess"""
        for i in range(self.num_workers):
            self.task_queues[i].put(
                (MpCommand.SET_COLLATE_FN, (dataloader_name, func)))

    def submit_task(self, dataloader_name, args):
        """Submit task to workers"""
        # Round robin
        self.task_queues[self.current_proc_id].put(
            (MpCommand.CALL_COLLATE_FN, (dataloader_name, args)))
        self.current_proc_id = (self.current_proc_id + 1) % self.num_workers

    def submit_task_to_all_workers(self, func, args):
        """Submit task to all workers"""
        for i in range(self.num_workers):
            self.task_queues[i].put(
                (MpCommand.CALL_FN_ALL_WORKERS, (func, args)))

    def get_result(self, dataloader_name, timeout=1800):
        """Get result from result queue"""
        result_dataloader_name, result = self.result_queue.get(timeout=timeout)
        assert result_dataloader_name == dataloader_name
        return result

    def delete_collate_fn(self, dataloader_name):
        """Delete collate function"""
        for i in range(self.num_workers):
            self.task_queues[i].put(
                (MpCommand.DELETE_COLLATE_FN, (dataloader_name, )))

    def call_barrier(self):
        """Call barrier at all workers"""
        for i in range(self.num_workers):
            self.task_queues[i].put(
                (MpCommand.CALL_BARRIER, tuple()))

    def close(self):
        """Close worker pool"""
        for i in range(self.num_workers):
            self.task_queues[i].put((MpCommand.FINALIZE_POOL, tuple()), block=False)
        time.sleep(0.5)  # Fix for early python version

    def join(self):
        """Join the close process of worker pool"""
        for i in range(self.num_workers):
            self.process_list[i].join()
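# How a dataloader is expected to drive this pool (an illustrative sketch, not
# part of the commit; "my_collate", "dataloader-0" and "batch_nids" are
# hypothetical names):
#
#   pool = CustomPool(4, (ip_config, num_servers, max_queue_size,
#                         net_type, 'sampler', num_worker_threads))
#   pool.set_collate_fn(my_collate, "dataloader-0")   # broadcast to every worker
#   pool.submit_task("dataloader-0", batch_nids)      # round-robin to one worker
#   result = pool.get_result("dataloader-0")          # read the shared result queue
#   pool.delete_collate_fn("dataloader-0")
#   pool.close()                                      # non-blocking FINALIZE_POOL
#   pool.join()
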
def initialize(ip_config, num_servers=1, num_workers=0,
               max_queue_size=MAX_QUEUE_SIZE, net_type='socket',
               num_worker_threads=1):
@@ -84,15 +215,15 @@ def initialize(ip_config, num_servers=1, num_workers=0,
    if os.environ.get('DGL_ROLE', 'client') == 'server':
        from .dist_graph import DistGraphServer
        assert os.environ.get('DGL_SERVER_ID') is not None, \
            'Please define DGL_SERVER_ID to run DistGraph server'
        assert os.environ.get('DGL_IP_CONFIG') is not None, \
            'Please define DGL_IP_CONFIG to run DistGraph server'
        assert os.environ.get('DGL_NUM_SERVER') is not None, \
            'Please define DGL_NUM_SERVER to run DistGraph server'
        assert os.environ.get('DGL_NUM_CLIENT') is not None, \
            'Please define DGL_NUM_CLIENT to run DistGraph server'
        assert os.environ.get('DGL_CONF_PATH') is not None, \
            'Please define DGL_CONF_PATH to run DistGraph server'
        formats = os.environ.get('DGL_GRAPH_FORMAT', 'csc').split(',')
        formats = [f.strip() for f in formats]
        serv = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
@@ -114,46 +245,47 @@ def initialize(ip_config, num_servers=1, num_workers=0,
    num_servers = 1
    rpc.reset()
-   ctx = mp.get_context("spawn")
    global SAMPLER_POOL
    global NUM_SAMPLER_WORKERS
    is_standalone = os.environ.get(
        'DGL_DIST_MODE', 'standalone') == 'standalone'
    if num_workers > 0 and not is_standalone:
-       SAMPLER_POOL = ctx.Pool(num_workers, initializer=_init_rpc,
-                               initargs=(ip_config, num_servers, max_queue_size,
-                                         net_type, 'sampler', num_worker_threads))
+       SAMPLER_POOL = CustomPool(num_workers, (ip_config, num_servers, max_queue_size,
+                                               net_type, 'sampler', num_worker_threads))
    else:
        SAMPLER_POOL = None
    NUM_SAMPLER_WORKERS = num_workers
    if not is_standalone:
        assert num_servers is not None and num_servers > 0, \
            'The number of servers per machine must be specified with a positive number.'
        connect_to_server(ip_config, num_servers, max_queue_size, net_type)
    init_role('default')
    init_kvstore(ip_config, num_servers, 'default')
def finalize_client():
    """Release resources of this client."""
    if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone':
        rpc.finalize_sender()
        rpc.finalize_receiver()
    global INITIALIZED
    INITIALIZED = False

def _exit():
    exit_client()
    time.sleep(1)

def finalize_worker():
    """Finalize workers
       Python's multiprocessing pool will not call atexit function when close
    """
-   global SAMPLER_POOL
    if SAMPLER_POOL is not None:
-       for _ in range(NUM_SAMPLER_WORKERS):
-           SAMPLER_POOL.apply_async(_exit)
-           time.sleep(0.1)  # This is necessary but I don't know why
        SAMPLER_POOL.close()

def join_finalize_worker():
    """join the worker close process"""
    global SAMPLER_POOL
@@ -161,11 +293,13 @@ def join_finalize_worker():
        SAMPLER_POOL.join()
    SAMPLER_POOL = None

def is_initialized():
    """Is RPC initialized?
    """
    return INITIALIZED

def exit_client():
    """Trainer exits
@@ -177,8 +311,8 @@ def exit_client():
    needs to call `exit_client` before calling `initialize` again.
    """
    # Only client with rank_0 will send shutdown request to servers.
    finalize_worker()  # finalize workers should be earlier than barrier, and non-blocking
    if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone':
        rpc.client_barrier()
        shutdown_servers()
        finalize_client()
...
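For context, `initialize()` above builds this pool only when `num_workers > 0` and `DGL_DIST_MODE` is not standalone; a trainer script normally touches only the public entry points. A rough sketch of that flow (assuming an existing `ip_config.txt` and a partitioned graph named `'mygraph'`; not taken from this commit):

import dgl.distributed as dist

dist.initialize('ip_config.txt', num_workers=4)   # spawns the sampler CustomPool
g = dist.DistGraph('mygraph')                     # connects to the running servers
# ... build a sampler / DistDataLoader and run training ...
dist.exit_client()                                # finalize_worker() runs before the client barrier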
# pylint: disable=global-variable-undefined, invalid-name
"""Multiprocess dataloader for distributed training"""
-import multiprocessing as mp
-from queue import Queue
-import traceback
from .dist_context import get_sampler_pool
from .. import backend as F

__all__ = ["DistDataLoader"]
-def call_collate_fn(name, next_data):
-    """Call collate function"""
-    try:
-        result = DGL_GLOBAL_COLLATE_FNS[name](next_data)
-        DGL_GLOBAL_MP_QUEUES[name].put(result)
-    except Exception as e:
-        traceback.print_exc()
-        print(e)
-        raise e
-    return 1
-
-DGL_GLOBAL_COLLATE_FNS = {}
-DGL_GLOBAL_MP_QUEUES = {}
-
-def init_fn(barrier, name, collate_fn, queue):
-    """Initialize setting collate function and mp.Queue in the subprocess"""
-    global DGL_GLOBAL_COLLATE_FNS
-    global DGL_GLOBAL_MP_QUEUES
-    DGL_GLOBAL_MP_QUEUES[name] = queue
-    DGL_GLOBAL_COLLATE_FNS[name] = collate_fn
-    barrier.wait()
-    return 1
-
-def cleanup_fn(barrier, name):
-    """Clean up the data of a dataloader in the worker process"""
-    global DGL_GLOBAL_COLLATE_FNS
-    global DGL_GLOBAL_MP_QUEUES
-    del DGL_GLOBAL_MP_QUEUES[name]
-    del DGL_GLOBAL_COLLATE_FNS[name]
-    # sleep here is to ensure this function is executed in all worker processes
-    # probably need better solution in the future
-    barrier.wait()
-    return 1
-
-def enable_mp_debug():
-    """Print multiprocessing debug information. This is only
-    for debug usage"""
-    import logging
-    logger = mp.log_to_stderr()
-    logger.setLevel(logging.DEBUG)
DATALOADER_ID = 0

class DistDataLoader:
    """DGL customized multiprocessing dataloader.
@@ -112,16 +66,12 @@ class DistDataLoader:
        self.pool, self.num_workers = get_sampler_pool()
        if queue_size is None:
            queue_size = self.num_workers * 4 if self.num_workers > 0 else 4
-       self.queue_size = queue_size
+       self.queue_size = queue_size  # prefetch size
        self.batch_size = batch_size
        self.num_pending = 0
        self.collate_fn = collate_fn
        self.current_pos = 0
-       if self.pool is not None:
-           m = mp.Manager()
-           self.queue = m.Queue(maxsize=queue_size)
-       else:
-           self.queue = Queue(maxsize=queue_size)
+       self.queue = []  # Only used when pool is None
        self.drop_last = drop_last
        self.recv_idxs = 0
        self.shuffle = shuffle
@@ -140,34 +90,24 @@ class DistDataLoader:
        DATALOADER_ID += 1
        if self.pool is not None:
-           results = []
-           barrier = m.Barrier(self.num_workers)
-           for _ in range(self.num_workers):
-               results.append(self.pool.apply_async(
-                   init_fn, args=(barrier, self.name, self.collate_fn, self.queue)))
-           for res in results:
-               res.get()
+           self.pool.set_collate_fn(self.collate_fn, self.name)
    def __del__(self):
        # When the process exits, the process pool may have been closed. We should try
        # and get the process pool again and see if we need to clean up the process pool.
        self.pool, self.num_workers = get_sampler_pool()
        if self.pool is not None:
-           results = []
-           # Here we need to create the manager and barrier again.
-           m = mp.Manager()
-           barrier = m.Barrier(self.num_workers)
-           for _ in range(self.num_workers):
-               results.append(self.pool.apply_async(cleanup_fn, args=(barrier, self.name,)))
-           for res in results:
-               res.get()
+           self.pool.delete_collate_fn(self.name)

    def __next__(self):
-       num_reqs = self.queue_size - self.num_pending
+       if self.pool is None:
+           num_reqs = 1
+       else:
+           num_reqs = self.queue_size - self.num_pending
        for _ in range(num_reqs):
            self._request_next_batch()
        if self.recv_idxs < self.expected_idxs:
-           result = self.queue.get(timeout=1800)
+           result = self._get_data_from_result_queue()
            self.recv_idxs += 1
            self.num_pending -= 1
            return result
@@ -175,6 +115,13 @@ class DistDataLoader:
            assert self.num_pending == 0
            raise StopIteration

+   def _get_data_from_result_queue(self, timeout=1800):
+       if self.pool is None:
+           ret = self.queue.pop(0)
+       else:
+           ret = self.pool.get_result(self.name, timeout=timeout)
+       return ret
    def __iter__(self):
        if self.shuffle:
            self.data_idx = F.rand_shuffle(self.data_idx)
@@ -188,10 +135,10 @@ class DistDataLoader:
        if next_data is None:
            return
        elif self.pool is not None:
-           self.pool.apply_async(call_collate_fn, args=(self.name, next_data, ))
+           self.pool.submit_task(self.name, next_data)
        else:
            result = self.collate_fn(next_data)
-           self.queue.put(result)
+           self.queue.append(result)
        self.num_pending += 1

    def _next_data(self):
...
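Taken together, `_request_next_batch()` pushes work into the pool and `__next__` keeps up to `queue_size` requests in flight, so sampling overlaps with training; without a pool it falls back to collating inline with a single pending request. A consumer-side sketch, assuming `dgl.distributed.initialize()` has already been called and the PyTorch backend is in use (`collate` and `train_nids` are hypothetical; a real script would typically pass a neighbor sampler's `sample_blocks`):

import torch as th
from dgl.distributed import DistDataLoader

def collate(batch_nids):
    # Placeholder for a sampling call, e.g. sampler.sample_blocks(batch_nids).
    return batch_nids

train_nids = th.arange(100000)          # seed node IDs owned by this trainer
dataloader = DistDataLoader(dataset=train_nids,
                            batch_size=1024,
                            collate_fn=collate,
                            shuffle=True,
                            drop_last=False)

for batch in dataloader:                # each batch is whatever collate returned
    pass                                # forward/backward pass would go here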