Unverified Commit cd204a4a authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Distributed] Use barrier instead of sleep in DistDataloader (#2086)

* use barrier instead of sleep

* lint
parent 6d212983
......@@ -2,7 +2,6 @@
"""Multiprocess dataloader for distributed training"""
import multiprocessing as mp
from queue import Queue
import time
import traceback
from .dist_context import get_sampler_pool
......@@ -25,18 +24,16 @@ def call_collate_fn(name, next_data):
DGL_GLOBAL_COLLATE_FNS = {}
DGL_GLOBAL_MP_QUEUES = {}
def init_fn(name, collate_fn, queue):
def init_fn(barrier, name, collate_fn, queue):
"""Initialize setting collate function and mp.Queue in the subprocess"""
global DGL_GLOBAL_COLLATE_FNS
global DGL_GLOBAL_MP_QUEUES
DGL_GLOBAL_MP_QUEUES[name] = queue
DGL_GLOBAL_COLLATE_FNS[name] = collate_fn
# sleep here is to ensure this function is executed in all worker processes
# probably need better solution in the future
time.sleep(1)
barrier.wait()
return 1
def cleanup_fn(name):
def cleanup_fn(barrier, name):
"""Clean up the data of a dataloader in the worker process"""
global DGL_GLOBAL_COLLATE_FNS
global DGL_GLOBAL_MP_QUEUES
......@@ -44,7 +41,7 @@ def cleanup_fn(name):
del DGL_GLOBAL_COLLATE_FNS[name]
# sleep here is to ensure this function is executed in all worker processes
# probably need better solution in the future
time.sleep(1)
barrier.wait()
return 1
......@@ -52,7 +49,7 @@ def enable_mp_debug():
"""Print multiprocessing debug information. This is only
for debug usage"""
import logging
logger = multiprocessing.log_to_stderr()
logger = mp.log_to_stderr()
logger.setLevel(logging.DEBUG)
DATALOADER_ID = 0
......@@ -122,6 +119,7 @@ class DistDataLoader:
self.current_pos = 0
if self.pool is not None:
self.m = mp.Manager()
self.barrier = self.m.Barrier(self.num_workers)
self.queue = self.m.Queue(maxsize=queue_size)
else:
self.queue = Queue(maxsize=queue_size)
......@@ -145,7 +143,7 @@ class DistDataLoader:
results = []
for _ in range(self.num_workers):
results.append(self.pool.apply_async(
init_fn, args=(self.name, self.collate_fn, self.queue)))
init_fn, args=(self.barrier, self.name, self.collate_fn, self.queue)))
for res in results:
res.get()
......@@ -153,7 +151,7 @@ class DistDataLoader:
if self.pool is not None:
results = []
for _ in range(self.num_workers):
results.append(self.pool.apply_async(cleanup_fn, args=(self.name,)))
results.append(self.pool.apply_async(cleanup_fn, args=(self.barrier, self.name,)))
for res in results:
res.get()
......@@ -162,7 +160,7 @@ class DistDataLoader:
for _ in range(num_reqs):
self._request_next_batch()
if self.recv_idxs < self.expected_idxs:
result = self.queue.get(timeout=9999)
result = self.queue.get(timeout=1800)
self.recv_idxs += 1
self.num_pending -= 1
return result
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment