Unverified Commit bab32d5b authored by Da Zheng, committed by GitHub

[Distributed] Fix dataloader (#1970)

* fix dataloader.

* initialize iterator of DistDataloader correctly.

* update test.
parent 264d96cd
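
Why the fix is needed: the old worker-side globals (DGL_GLOBAL_COLLATE_FN, DGL_GLOBAL_MP_QUEUE) could hold state for only one dataloader, so creating a second DistDataLoader over the same worker pool silently replaced the first one's collate function and result queue. The patch keys both by a unique per-dataloader name. A minimal repro of the clobbering, using hypothetical names (not DGL code):

# Single worker-side global: the second registration wins.
WORKER_COLLATE_FN = None

def old_init(collate_fn):
    global WORKER_COLLATE_FN
    WORKER_COLLATE_FN = collate_fn

old_init(lambda batch: ("A", batch))
old_init(lambda batch: ("B", batch))       # dataloader B registers next...
print(WORKER_COLLATE_FN([1]))              # ('B', [1]): A's batches are now mis-collated

# Keying by a unique name lets both dataloaders coexist, as in the patch.
WORKER_COLLATE_FNS = {}
WORKER_COLLATE_FNS["dataloader-0"] = lambda batch: ("A", batch)
WORKER_COLLATE_FNS["dataloader-1"] = lambda batch: ("B", batch)
print(WORKER_COLLATE_FNS["dataloader-0"]([1]))  # ('A', [1])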
@@ -10,24 +10,37 @@ from .. import backend as F

 __all__ = ["DistDataLoader"]

-def call_collate_fn(next_data):
+def call_collate_fn(name, next_data):
     """Call collate function"""
     try:
-        result = DGL_GLOBAL_COLLATE_FN(next_data)
-        DGL_GLOBAL_MP_QUEUE.put(result)
+        result = DGL_GLOBAL_COLLATE_FNS[name](next_data)
+        DGL_GLOBAL_MP_QUEUES[name].put(result)
     except Exception as e:
         traceback.print_exc()
         print(e)
         raise e
     return 1

+DGL_GLOBAL_COLLATE_FNS = {}
+DGL_GLOBAL_MP_QUEUES = {}

-def init_fn(collate_fn, queue):
+def init_fn(name, collate_fn, queue):
     """Initialize setting collate function and mp.Queue in the subprocess"""
-    global DGL_GLOBAL_COLLATE_FN
-    global DGL_GLOBAL_MP_QUEUE
-    DGL_GLOBAL_MP_QUEUE = queue
-    DGL_GLOBAL_COLLATE_FN = collate_fn
+    global DGL_GLOBAL_COLLATE_FNS
+    global DGL_GLOBAL_MP_QUEUES
+    DGL_GLOBAL_MP_QUEUES[name] = queue
+    DGL_GLOBAL_COLLATE_FNS[name] = collate_fn
     # sleep here is to ensure this function is executed in all worker processes
     # probably need better solution in the future
     time.sleep(1)
     return 1
+
+def cleanup_fn(name):
+    """Clean up the data of a dataloader in the worker process"""
+    global DGL_GLOBAL_COLLATE_FNS
+    global DGL_GLOBAL_MP_QUEUES
+    del DGL_GLOBAL_MP_QUEUES[name]
+    del DGL_GLOBAL_COLLATE_FNS[name]
+    # sleep here is to ensure this function is executed in all worker processes
+    # probably need better solution in the future
+    time.sleep(1)
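
How the registry is populated: init_fn is not a Pool initializer; the dataloader broadcasts it to the workers as num_workers ordinary tasks, and the time.sleep(1) keeps each worker busy long enough that no single worker drains several init tasks while another gets none. A self-contained sketch of that pattern, with hypothetical names (REGISTRY, worker_init, worker_collate are mine, not DGL's):

import multiprocessing as mp
import time

REGISTRY = {}  # worker-side state: dataloader name -> collate function

def worker_init(name, collate_fn):
    REGISTRY[name] = collate_fn
    time.sleep(1)  # crude barrier: hold this worker so each of the
                   # num_workers init tasks lands on a different worker
    return 1

def worker_collate(name, batch):
    return REGISTRY[name](batch)  # look up this dataloader's collate function

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    pool = ctx.Pool(2)
    # one init task per worker, then wait for all of them
    results = [pool.apply_async(worker_init, ("dataloader-0", sum))
               for _ in range(2)]
    for res in results:
        res.get()
    print(pool.apply_async(worker_collate, ("dataloader-0", [1, 2, 3])).get())  # 6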
@@ -41,6 +54,7 @@ def enable_mp_debug():
     logger = multiprocessing.log_to_stderr()
     logger.setLevel(logging.DEBUG)

+DATALOADER_ID = 0

 class DistDataLoader:
     """DGL customized multiprocessing dataloader, designed for use with DistGraph."""
@@ -90,19 +104,32 @@ class DistDataLoader:
         if self.pool is None:
             ctx = mp.get_context("spawn")
             self.pool = ctx.Pool(num_workers)
-        results = []
-        for _ in range(num_workers):
-            results.append(self.pool.apply_async(
-                init_fn, args=(collate_fn, self.queue)))
-        for res in results:
-            res.get()
         self.dataset = F.tensor(dataset)
         self.expected_idxs = len(dataset) // self.batch_size
         if not self.drop_last and len(dataset) % self.batch_size != 0:
             self.expected_idxs += 1

+        # We need to have a unique ID for each data loader to identify itself
+        # in the sampler processes.
+        global DATALOADER_ID
+        self.name = "dataloader-" + str(DATALOADER_ID)
+        DATALOADER_ID += 1
+
+        results = []
+        for _ in range(self.num_workers):
+            results.append(self.pool.apply_async(
+                init_fn, args=(self.name, self.collate_fn, self.queue)))
+        for res in results:
+            res.get()
+
+    def __del__(self):
+        results = []
+        for _ in range(self.num_workers):
+            results.append(self.pool.apply_async(cleanup_fn, args=(self.name,)))
+        for res in results:
+            res.get()
+
     def __next__(self):
         if not self.started:
             for _ in range(self.queue_size):
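
The teardown mirrors the setup: __del__ broadcasts cleanup_fn with the same one-task-per-worker pattern, so each worker drops its registry entries and a long-lived pool does not accumulate collate functions and queues from dead dataloaders. Continuing the sketch above (same hypothetical names):

def worker_cleanup(name):
    REGISTRY.pop(name, None)  # drop this dataloader's entry in the worker
    time.sleep(1)             # same crude barrier as registration
    return 1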
@@ -113,13 +140,13 @@ class DistDataLoader:
             self.recv_idxs += 1
             return result
         else:
-            self.recv_idxs = 0
-            self.current_pos = 0
             raise StopIteration

     def __iter__(self):
         if self.shuffle:
             self.dataset = F.rand_shuffle(self.dataset)
+        self.recv_idxs = 0
+        self.current_pos = 0
         return self

     def _request_next_batch(self):
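
The second fix is where the iteration counters are reset. Resetting recv_idxs and current_pos only in the StopIteration branch means a loop that is abandoned early (via break or an exception) leaves stale counters behind, so the next for loop over the same dataloader starts mid-epoch. Resetting in __iter__ makes every new loop start clean. A toy illustration of the pattern (not DGL code):

class Counter:
    """Iterator that resets its cursor in __iter__, as the patch does."""
    def __init__(self, n):
        self.n = n
        self.pos = 0

    def __iter__(self):
        self.pos = 0   # reset here, not in the StopIteration branch,
        return self    # so an abandoned loop cannot skew the next one

    def __next__(self):
        if self.pos >= self.n:
            raise StopIteration
        self.pos += 1
        return self.pos - 1

c = Counter(3)
for x in c:
    break              # abandon the first loop early
print(list(c))         # [0, 1, 2]: the second pass still sees every item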
@@ -128,7 +155,7 @@ class DistDataLoader:
             return None
         else:
             async_result = self.pool.apply_async(
-                call_collate_fn, args=(next_data, ))
+                call_collate_fn, args=(self.name, next_data, ))
             return async_result

     def _next_data(self):
......
@@ -62,6 +62,8 @@ def start_client(rank, tmpdir, disable_shared_mem, num_workers, drop_last):
     sampler = NeighborSampler(dist_graph, [5, 10],
                               dgl.distributed.sample_neighbors)

+    # We need to test creating DistDataLoader multiple times.
+    for i in range(2):
         # Create DataLoader for constructing blocks
         dataloader = DistDataLoader(
             dataset=train_nid.numpy(),
......