"src/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "d876680a36fcbd994b9e6b6c42253bd47713b276"
Unverified Commit cd204a4a authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Distributed] Use barrier instead of sleep in DistDataloader (#2086)

* use barrier instead of sleep

* lint
parent 6d212983
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
"""Multiprocess dataloader for distributed training""" """Multiprocess dataloader for distributed training"""
import multiprocessing as mp import multiprocessing as mp
from queue import Queue from queue import Queue
import time
import traceback import traceback
from .dist_context import get_sampler_pool from .dist_context import get_sampler_pool
...@@ -25,18 +24,16 @@ def call_collate_fn(name, next_data): ...@@ -25,18 +24,16 @@ def call_collate_fn(name, next_data):
# Per-worker registries, keyed by dataloader name. Each sampler subprocess
# holds its own copy of these module-level dicts.
DGL_GLOBAL_COLLATE_FNS = {}
DGL_GLOBAL_MP_QUEUES = {}


def init_fn(barrier, name, collate_fn, queue):
    """Initialize setting collate function and mp.Queue in the subprocess.

    Registers *collate_fn* and *queue* under *name* in this worker's
    module-level registries, then waits on *barrier* so that every pool
    worker has finished registering before any of them proceeds.
    """
    global DGL_GLOBAL_COLLATE_FNS
    global DGL_GLOBAL_MP_QUEUES
    DGL_GLOBAL_COLLATE_FNS[name] = collate_fn
    DGL_GLOBAL_MP_QUEUES[name] = queue
    # Rendezvous across all workers: guarantees this initializer has run
    # everywhere (replaces the previous time.sleep(1) heuristic).
    barrier.wait()
    return 1
def cleanup_fn(name): def cleanup_fn(barrier, name):
"""Clean up the data of a dataloader in the worker process""" """Clean up the data of a dataloader in the worker process"""
global DGL_GLOBAL_COLLATE_FNS global DGL_GLOBAL_COLLATE_FNS
global DGL_GLOBAL_MP_QUEUES global DGL_GLOBAL_MP_QUEUES
...@@ -44,7 +41,7 @@ def cleanup_fn(name): ...@@ -44,7 +41,7 @@ def cleanup_fn(name):
del DGL_GLOBAL_COLLATE_FNS[name] del DGL_GLOBAL_COLLATE_FNS[name]
# sleep here is to ensure this function is executed in all worker processes # sleep here is to ensure this function is executed in all worker processes
# probably need better solution in the future # probably need better solution in the future
time.sleep(1) barrier.wait()
return 1 return 1
def enable_mp_debug():
    """Print multiprocessing debug information. This is only
    for debug usage.

    Sends the ``multiprocessing`` module's internal logger to stderr and
    raises its level to DEBUG so pool/worker lifecycle events are shown.
    """
    import logging
    # Must use the `mp` alias from the top-of-file import; the bare name
    # `multiprocessing` is not bound in this module and raised NameError.
    logger = mp.log_to_stderr()
    logger.setLevel(logging.DEBUG)


# Monotonically increasing counter used to derive a unique name per
# DistDataLoader instance.
DATALOADER_ID = 0
...@@ -122,6 +119,7 @@ class DistDataLoader: ...@@ -122,6 +119,7 @@ class DistDataLoader:
self.current_pos = 0 self.current_pos = 0
if self.pool is not None: if self.pool is not None:
self.m = mp.Manager() self.m = mp.Manager()
self.barrier = self.m.Barrier(self.num_workers)
self.queue = self.m.Queue(maxsize=queue_size) self.queue = self.m.Queue(maxsize=queue_size)
else: else:
self.queue = Queue(maxsize=queue_size) self.queue = Queue(maxsize=queue_size)
...@@ -145,7 +143,7 @@ class DistDataLoader: ...@@ -145,7 +143,7 @@ class DistDataLoader:
results = [] results = []
for _ in range(self.num_workers): for _ in range(self.num_workers):
results.append(self.pool.apply_async( results.append(self.pool.apply_async(
init_fn, args=(self.name, self.collate_fn, self.queue))) init_fn, args=(self.barrier, self.name, self.collate_fn, self.queue)))
for res in results: for res in results:
res.get() res.get()
...@@ -153,7 +151,7 @@ class DistDataLoader: ...@@ -153,7 +151,7 @@ class DistDataLoader:
if self.pool is not None: if self.pool is not None:
results = [] results = []
for _ in range(self.num_workers): for _ in range(self.num_workers):
results.append(self.pool.apply_async(cleanup_fn, args=(self.name,))) results.append(self.pool.apply_async(cleanup_fn, args=(self.barrier, self.name,)))
for res in results: for res in results:
res.get() res.get()
...@@ -162,7 +160,7 @@ class DistDataLoader: ...@@ -162,7 +160,7 @@ class DistDataLoader:
for _ in range(num_reqs): for _ in range(num_reqs):
self._request_next_batch() self._request_next_batch()
if self.recv_idxs < self.expected_idxs: if self.recv_idxs < self.expected_idxs:
result = self.queue.get(timeout=9999) result = self.queue.get(timeout=1800)
self.recv_idxs += 1 self.recv_idxs += 1
self.num_pending -= 1 self.num_pending -= 1
return result return result
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment