Unverified Commit 5ebd3bf0 authored by Quan (Andy) Gan, committed by GitHub

[DataLoader] Disable the usage of shared memory when persistent_workers=False on single GPU (#4497)



* toggle shared memory usage

* Update dataloader.py
Co-authored-by: Xin Yao <xiny@nvidia.com>
parent 2630d2eb
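
For context, a minimal sketch (not part of the patch) of the shared-memory behavior this commit makes conditional; share_memory_() and is_shared() are standard torch.Tensor methods:

    import torch

    # share_memory_() moves the tensor's storage into shared memory in place,
    # so worker processes that hold a reference keep seeing later in-place
    # updates, such as the per-epoch shuffle of self._indices when
    # persistent_workers=True.
    indices = torch.arange(6, dtype=torch.int64)
    print(indices.is_shared())  # False: ordinary process-private storage
    indices.share_memory_()     # returns self; storage now lives in shared memory
    print(indices.is_shared())  # True

Skipping this step when persistent_workers=False avoids allocating a shared-memory segment per DataLoader on a single GPU, which is the point of the change.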
@@ -134,7 +134,7 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
     """Custom Dataset wrapper that returns a minibatch as tensors or dicts of tensors.
     When the dataset is on the GPU, this significantly reduces the overhead.
     """
-    def __init__(self, indices, batch_size, drop_last, shuffle):
+    def __init__(self, indices, batch_size, drop_last, shuffle, use_shared_memory):
         if isinstance(indices, Mapping):
             self._mapping_keys = list(indices.keys())
             self._device = next(iter(indices.values())).device
@@ -147,7 +147,9 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
         # Use a shared memory array to permute indices for shuffling. This is to make sure that
         # the worker processes can see it when persistent_workers=True, where self._indices
         # would not be duplicated every epoch.
-        self._indices = torch.arange(self._id_tensor.shape[0], dtype=torch.int64).share_memory_()
+        self._indices = torch.arange(self._id_tensor.shape[0], dtype=torch.int64)
+        if use_shared_memory:
+            self._indices.share_memory_()
         self.batch_size = batch_size
         self.drop_last = drop_last
         self._shuffle = shuffle
@@ -551,15 +553,16 @@ class WorkerInitWrapper(object):

 def create_tensorized_dataset(indices, batch_size, drop_last, use_ddp, ddp_seed,
-                              shuffle):
+                              shuffle, use_shared_memory):
     """Converts a given indices tensor to a TensorizedDataset, an IterableDataset
     that returns views of the original tensor, to reduce overhead from having
     a list of scalar tensors in default PyTorch DataLoader implementation.
     """
     if use_ddp:
+        # DDP always uses shared memory
         return DDPTensorizedDataset(indices, batch_size, drop_last, ddp_seed, shuffle)
     else:
-        return TensorizedDataset(indices, batch_size, drop_last, shuffle)
+        return TensorizedDataset(indices, batch_size, drop_last, shuffle, use_shared_memory)


 def _get_device(device):
@@ -825,7 +828,8 @@ class DataLoader(torch.utils.data.DataLoader):
                 isinstance(indices, Mapping) and
                 all(torch.is_tensor(v) for v in indices.values()))):
             self.dataset = create_tensorized_dataset(
-                indices, batch_size, drop_last, use_ddp, ddp_seed, shuffle)
+                indices, batch_size, drop_last, use_ddp, ddp_seed, shuffle,
+                kwargs.get('persistent_workers', False))
         else:
             self.dataset = indices
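
Taken together, the index tensor is now placed in shared memory only when the caller opts into persistent workers. A rough usage sketch, assuming DGL's public dgl.dataloading.DataLoader API; the graph, sampler, and sizes below are illustrative, not taken from the patch:

    import torch
    import dgl

    g = dgl.rand_graph(100, 500)                       # toy graph for illustration
    sampler = dgl.dataloading.NeighborSampler([5, 5])
    nids = torch.arange(g.num_nodes())

    # persistent_workers=True: workers outlive an epoch, so the wrapped
    # TensorizedDataset keeps its indices in shared memory for reshuffling.
    # With the default persistent_workers=False, share_memory_() is skipped.
    loader = dgl.dataloading.DataLoader(
        g, nids, sampler, batch_size=32, shuffle=True,
        num_workers=2, persistent_workers=True)

The persistent_workers flag reaches the dataset via kwargs.get('persistent_workers', False) in the last hunk above, so no new DataLoader argument is introduced.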