Unverified Commit 33e80452 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Set persistent_workers in MultiProcessDataLoader (#6592)

parent 81ac9d27
......@@ -78,14 +78,19 @@ class MultiprocessingWrapper(dp.iter.IterDataPipe):
num_workers : int, optional
The number of worker processes. Default is 0, meaning that there
will be no multiprocessing.
persistent_workers : bool, optional
If True, the data loader will not shut down the worker processes after a
dataset has been consumed once. This keeps the worker
instances alive across epochs.
"""
def __init__(self, datapipe, num_workers=0):
def __init__(self, datapipe, num_workers=0, persistent_workers=True):
self.datapipe = datapipe
self.dataloader = torch.utils.data.DataLoader(
datapipe,
batch_size=None,
num_workers=num_workers,
persistent_workers=(num_workers > 0) and persistent_workers,
)
def __iter__(self):
......@@ -109,9 +114,13 @@ class MultiProcessDataLoader(torch.utils.data.DataLoader):
num_workers : int, optional
Number of worker processes. Default is 0, which is identical to
:class:`SingleProcessDataLoader`.
persistent_workers : bool, optional
If True, the data loader will not shut down the worker processes after a
dataset has been consumed once. This keeps the worker
instances alive across epochs.
"""
def __init__(self, datapipe, num_workers=0):
def __init__(self, datapipe, num_workers=0, persistent_workers=True):
# Multiprocessing requires two modifications to the datapipe:
#
# 1. Insert a stage after ItemSampler to distribute the
......@@ -144,6 +153,7 @@ class MultiProcessDataLoader(torch.utils.data.DataLoader):
FeatureFetcher,
MultiprocessingWrapper,
num_workers=num_workers,
persistent_workers=persistent_workers,
)
# (3) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment