Unverified Commit 408eba24 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

Multiple dist dl sampler (#4704)

* [Dist] enable iterating over multiple dist dataloaders simultaneously

* format file

* add support for any number of dataloaders

* fix lint

* refine code
parent ea48ce7a
......@@ -12,6 +12,7 @@ import traceback
from enum import Enum
from .. import utils
from ..base import DGLError
from . import rpc
from .constants import MAX_QUEUE_SIZE
from .kvstore import close_kvstore, init_kvstore
......@@ -122,8 +123,11 @@ class CustomPool:
"""
ctx = mp.get_context("spawn")
self.num_workers = num_workers
self.queue_size = num_workers * 4
# As the pool can be shared by any number of dataloaders, the queues
# must be unbounded (size 0) to avoid deadlock.
self.queue_size = 0
self.result_queue = ctx.Queue(self.queue_size)
self.results = {} # key is dataloader name, value is fetched batch.
self.task_queues = []
self.process_list = []
self.current_proc_id = 0
......@@ -149,6 +153,7 @@ class CustomPool:
self.task_queues[i].put(
(MpCommand.SET_COLLATE_FN, (dataloader_name, func))
)
self.results[dataloader_name] = []
def submit_task(self, dataloader_name, args):
"""Submit task to workers"""
......@@ -167,9 +172,14 @@ class CustomPool:
def get_result(self, dataloader_name, timeout=1800):
    """Return the oldest pending result for ``dataloader_name``.

    All dataloaders share ``self.result_queue``, so the next item on the
    queue may belong to a different dataloader.  Items are therefore drained
    into the per-dataloader lists in ``self.results`` until the requested
    dataloader has at least one buffered result.

    Parameters
    ----------
    dataloader_name : str
        Name under which the dataloader's buffer was created in
        ``self.results``.
    timeout : int or float, optional
        Seconds to wait for *each* read from ``self.result_queue``; the
        total wait can exceed this when results of other dataloaders
        arrive first.

    Returns
    -------
    The first result buffered for ``dataloader_name`` (FIFO order).

    Raises
    ------
    DGLError
        If ``dataloader_name`` has no buffer in ``self.results``.
    queue.Empty
        Propagated from ``self.result_queue.get`` when no item arrives
        within ``timeout``.
    """
    if dataloader_name not in self.results:
        raise DGLError(
            f"Got result from an unknown dataloader {dataloader_name}."
        )
    pending = self.results[dataloader_name]
    # ``pending`` is the same list object stored in ``self.results``, so it
    # fills up as soon as a matching item is appended below.
    while not pending:
        dl_name, data = self.result_queue.get(timeout=timeout)
        # Park the item under its owner; it may not be ours.
        self.results[dl_name].append(data)
    return pending.pop(0)
def delete_collate_fn(self, dataloader_name):
"""Delete collate function"""
......@@ -177,6 +187,7 @@ class CustomPool:
self.task_queues[i].put(
(MpCommand.DELETE_COLLATE_FN, (dataloader_name,))
)
del self.results[dataloader_name]
def call_barrier(self):
"""Call barrier at all workers"""
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment