"docs/git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "5b35624528a1babed91fbb77f8981a62ede36b1f"
Unverified Commit 33e80452 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Set persistent_workers in MultiProcessDataLoader (#6592)

parent 81ac9d27
...@@ -78,14 +78,19 @@ class MultiprocessingWrapper(dp.iter.IterDataPipe): ...@@ -78,14 +78,19 @@ class MultiprocessingWrapper(dp.iter.IterDataPipe):
num_workers : int, optional num_workers : int, optional
The number of worker processes. Default is 0, meaning that there The number of worker processes. Default is 0, meaning that there
will be no multiprocessing. will be no multiprocessing.
persistent_workers : bool, optional
If True, the data loader will not shut down the worker processes after a
dataset has been consumed once. This keeps the worker instances alive.
""" """
def __init__(self, datapipe, num_workers=0, persistent_workers=True):
    """Wrap *datapipe* in a multiprocessing ``torch.utils.data.DataLoader``.

    Parameters
    ----------
    datapipe : IterDataPipe
        The data pipeline to iterate over.
    num_workers : int, optional
        The number of worker processes. Default is 0, meaning that there
        will be no multiprocessing.
    persistent_workers : bool, optional
        If True, the worker processes are not shut down after the dataset
        has been consumed once, keeping the worker instances alive. Only
        takes effect when ``num_workers > 0``.
    """
    self.datapipe = datapipe
    self.dataloader = torch.utils.data.DataLoader(
        datapipe,
        # Batching is assumed to be handled upstream in the datapipe,
        # so the DataLoader must not add a batch dimension itself.
        batch_size=None,
        num_workers=num_workers,
        # DataLoader rejects persistent_workers=True when num_workers == 0,
        # so only enable it when worker processes actually exist.
        persistent_workers=(num_workers > 0) and persistent_workers,
    )
def __iter__(self): def __iter__(self):
...@@ -109,9 +114,13 @@ class MultiProcessDataLoader(torch.utils.data.DataLoader): ...@@ -109,9 +114,13 @@ class MultiProcessDataLoader(torch.utils.data.DataLoader):
num_workers : int, optional num_workers : int, optional
Number of worker processes. Default is 0, which is identical to Number of worker processes. Default is 0, which is identical to
:class:`SingleProcessDataLoader`. :class:`SingleProcessDataLoader`.
persistent_workers : bool, optional
If True, the data loader will not shut down the worker processes after a
dataset has been consumed once. This keeps the worker instances alive.
""" """
def __init__(self, datapipe, num_workers=0): def __init__(self, datapipe, num_workers=0, persistent_workers=True):
# Multiprocessing requires two modifications to the datapipe: # Multiprocessing requires two modifications to the datapipe:
# #
# 1. Insert a stage after ItemSampler to distribute the # 1. Insert a stage after ItemSampler to distribute the
...@@ -144,6 +153,7 @@ class MultiProcessDataLoader(torch.utils.data.DataLoader): ...@@ -144,6 +153,7 @@ class MultiProcessDataLoader(torch.utils.data.DataLoader):
FeatureFetcher, FeatureFetcher,
MultiprocessingWrapper, MultiprocessingWrapper,
num_workers=num_workers, num_workers=num_workers,
persistent_workers=persistent_workers,
) )
# (3) Cut datapipe at CopyTo and wrap with prefetcher. This enables the # (3) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment