Unverified Commit 5ebd3bf0 authored by Quan (Andy) Gan, committed by GitHub

[DataLoader] Disable the usage of shared memory when persistent_workers=False on single GPU (#4497)



* toggle shared memory usage

* Update dataloader.py
Co-authored-by: Xin Yao <xiny@nvidia.com>
parent 2630d2eb
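
Not part of the diff below, but for context, here is a minimal standalone sketch of the behavior this commit relies on. With fork-started workers (the Linux default for PyTorch's DataLoader), a child process sees the parent's post-fork writes to a tensor only if that tensor was moved into shared memory before the fork. Persistent workers outlive an epoch, so they need shared memory to observe each epoch's re-shuffle; freshly forked workers inherit the already-shuffled indices for free, which is why the allocation can be skipped when persistent_workers=False. All names in this sketch are illustrative:

```python
import torch
import torch.multiprocessing as mp

# Illustrative only: comment out share_memory_() and the child prints the
# stale pre-fork values instead, because fork only gives it copy-on-write pages.
indices = torch.arange(4, dtype=torch.int64)
indices.share_memory_()

def child(barrier):
    barrier.wait()                          # wait for the parent's in-place update
    print('child sees:', indices.tolist())  # [10, 11, 12, 13] only with shared memory

if __name__ == '__main__':
    mp.set_start_method('fork')             # mirrors forked DataLoader workers on Linux
    barrier = mp.Barrier(2)
    proc = mp.Process(target=child, args=(barrier,))
    proc.start()
    indices += 10                           # parent mutates the tensor after the fork
    barrier.wait()
    proc.join()
```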
@@ -134,7 +134,7 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
     """Custom Dataset wrapper that returns a minibatch as tensors or dicts of tensors.
     When the dataset is on the GPU, this significantly reduces the overhead.
     """
-    def __init__(self, indices, batch_size, drop_last, shuffle):
+    def __init__(self, indices, batch_size, drop_last, shuffle, use_shared_memory):
         if isinstance(indices, Mapping):
             self._mapping_keys = list(indices.keys())
             self._device = next(iter(indices.values())).device
@@ -147,7 +147,9 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
         # Use a shared memory array to permute indices for shuffling. This is to make sure that
         # the worker processes can see it when persistent_workers=True, where self._indices
         # would not be duplicated every epoch.
-        self._indices = torch.arange(self._id_tensor.shape[0], dtype=torch.int64).share_memory_()
+        self._indices = torch.arange(self._id_tensor.shape[0], dtype=torch.int64)
+        if use_shared_memory:
+            self._indices.share_memory_()
         self.batch_size = batch_size
         self.drop_last = drop_last
         self._shuffle = shuffle
@@ -551,15 +553,16 @@ class WorkerInitWrapper(object):
 def create_tensorized_dataset(indices, batch_size, drop_last, use_ddp, ddp_seed,
-                              shuffle):
+                              shuffle, use_shared_memory):
     """Converts a given indices tensor to a TensorizedDataset, an IterableDataset
     that returns views of the original tensor, to reduce overhead from having
     a list of scalar tensors in default PyTorch DataLoader implementation.
     """
     if use_ddp:
+        # DDP always uses shared memory
         return DDPTensorizedDataset(indices, batch_size, drop_last, ddp_seed, shuffle)
     else:
-        return TensorizedDataset(indices, batch_size, drop_last, shuffle)
+        return TensorizedDataset(indices, batch_size, drop_last, shuffle, use_shared_memory)


 def _get_device(device):
@@ -825,7 +828,8 @@ class DataLoader(torch.utils.data.DataLoader):
                 isinstance(indices, Mapping) and
                 all(torch.is_tensor(v) for v in indices.values()))):
             self.dataset = create_tensorized_dataset(
-                indices, batch_size, drop_last, use_ddp, ddp_seed, shuffle)
+                indices, batch_size, drop_last, use_ddp, ddp_seed, shuffle,
+                kwargs.get('persistent_workers', False))
         else:
             self.dataset = indices
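
For reference, a hypothetical call site showing how the new flag flows through. The random graph, node IDs, and sampler choice below are placeholders; the relevant detail from the diff is only that DataLoader picks persistent_workers out of kwargs and forwards it as use_shared_memory:

```python
import dgl
import torch

# Toy inputs for illustration only.
g = dgl.rand_graph(1000, 5000)
train_nids = torch.arange(1000)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)

loader = dgl.dataloading.DataLoader(
    g, train_nids, sampler,
    batch_size=256, shuffle=True, drop_last=False, num_workers=2,
    # persistent_workers=True keeps workers alive across epochs, so the
    # shuffled index tensor goes into shared memory; with the default False,
    # this commit now skips share_memory_() entirely.
    persistent_workers=True,
)

for input_nodes, output_nodes, blocks in loader:
    pass  # one sampled minibatch per iteration
```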