Unverified Commit 408eba24 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

Multiple dist dl sampler (#4704)

* [Dist] enable iterating over multiple dist dataloaders simultaneously

* format file

* add support for any number of dataloaders

* fix lint

* refine code
parent ea48ce7a
...@@ -12,6 +12,7 @@ import traceback ...@@ -12,6 +12,7 @@ import traceback
from enum import Enum from enum import Enum
from .. import utils from .. import utils
from ..base import DGLError
from . import rpc from . import rpc
from .constants import MAX_QUEUE_SIZE from .constants import MAX_QUEUE_SIZE
from .kvstore import close_kvstore, init_kvstore from .kvstore import close_kvstore, init_kvstore
...@@ -122,8 +123,11 @@ class CustomPool: ...@@ -122,8 +123,11 @@ class CustomPool:
""" """
ctx = mp.get_context("spawn") ctx = mp.get_context("spawn")
self.num_workers = num_workers self.num_workers = num_workers
self.queue_size = num_workers * 4 # As pool could be used by any number of dataloaders, queues
# should be able to take infinite elements to avoid deadlock.
self.queue_size = 0
self.result_queue = ctx.Queue(self.queue_size) self.result_queue = ctx.Queue(self.queue_size)
self.results = {} # key is dataloader name, value is fetched batch.
self.task_queues = [] self.task_queues = []
self.process_list = [] self.process_list = []
self.current_proc_id = 0 self.current_proc_id = 0
...@@ -149,6 +153,7 @@ class CustomPool: ...@@ -149,6 +153,7 @@ class CustomPool:
self.task_queues[i].put( self.task_queues[i].put(
(MpCommand.SET_COLLATE_FN, (dataloader_name, func)) (MpCommand.SET_COLLATE_FN, (dataloader_name, func))
) )
self.results[dataloader_name] = []
def submit_task(self, dataloader_name, args): def submit_task(self, dataloader_name, args):
"""Submit task to workers""" """Submit task to workers"""
...@@ -167,9 +172,14 @@ class CustomPool: ...@@ -167,9 +172,14 @@ class CustomPool:
def get_result(self, dataloader_name, timeout=1800): def get_result(self, dataloader_name, timeout=1800):
"""Get result from result queue""" """Get result from result queue"""
result_dataloader_name, result = self.result_queue.get(timeout=timeout) if dataloader_name not in self.results:
assert result_dataloader_name == dataloader_name raise DGLError(
return result f"Got result from an unknown dataloader {dataloader_name}."
)
while len(self.results[dataloader_name]) == 0:
dl_name, data = self.result_queue.get(timeout=timeout)
self.results[dl_name].append(data)
return self.results[dataloader_name].pop(0)
def delete_collate_fn(self, dataloader_name): def delete_collate_fn(self, dataloader_name):
"""Delete collate function""" """Delete collate function"""
...@@ -177,6 +187,7 @@ class CustomPool: ...@@ -177,6 +187,7 @@ class CustomPool:
self.task_queues[i].put( self.task_queues[i].put(
(MpCommand.DELETE_COLLATE_FN, (dataloader_name,)) (MpCommand.DELETE_COLLATE_FN, (dataloader_name,))
) )
del self.results[dataloader_name]
def call_barrier(self): def call_barrier(self):
"""Call barrier at all workers""" """Call barrier at all workers"""
......
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment