Unverified commit 4c147814 authored by Quan (Andy) Gan, committed by GitHub

[Optimization] Memory consumption optimization on index shuffling in dataloader (#3980)

* fix

* revert

* Update dataloader.py
parent 65e6b04d
@@ -11,6 +11,7 @@ import atexit
 import os
 import psutil
+import numpy as np
 import torch
 import torch.distributed as dist
 from torch.utils.data.distributed import DistributedSampler
@@ -134,15 +135,13 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
         # Use a shared memory array to permute indices for shuffling. This is to make sure that
         # the worker processes can see it when persistent_workers=True, where self._indices
         # would not be duplicated every epoch.
-        self._indices = torch.empty(self._id_tensor.shape[0], dtype=torch.int64).share_memory_()
-        self._indices[:] = torch.arange(self._id_tensor.shape[0])
+        self._indices = torch.arange(self._id_tensor.shape[0], dtype=torch.int64).share_memory_()
         self.batch_size = batch_size
         self.drop_last = drop_last
 
     def shuffle(self):
         """Shuffle the dataset."""
-        # TODO: may need an in-place shuffle kernel
-        self._indices[:] = self._indices[torch.randperm(self._indices.shape[0])]
+        np.random.shuffle(self._indices.numpy())
 
     def __iter__(self):
         indices = _divide_by_worker(self._indices, self.batch_size, self.drop_last)
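
The gain in this hunk comes from two things: `torch.arange(...).share_memory_()` builds the index tensor directly in shared memory instead of allocating with `torch.empty` and copying an `arange` into it, and `np.random.shuffle` permutes that buffer in place, whereas the old `torch.randperm` path materialized both a permutation tensor and a gathered copy of the indices. A minimal standalone sketch of the idea (illustrative only, not the DGL source):

```python
import numpy as np
import torch

# Build the index tensor directly in shared memory; no torch.empty + copy step.
indices = torch.arange(10, dtype=torch.int64).share_memory_()

# Tensor.numpy() on a CPU tensor is a zero-copy view of the same storage, so
# np.random.shuffle permutes the shared buffer in place without allocating a
# permutation tensor or a gathered copy of the indices.
np.random.shuffle(indices.numpy())

assert indices.is_shared()  # still backed by shared memory after the shuffle
print(indices)
```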
@@ -203,16 +202,20 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
     def _create_shared_indices(self):
         indices = torch.empty(self.shared_mem_size, dtype=torch.int64)
         num_ids = self._id_tensor.shape[0]
-        indices[:num_ids] = torch.arange(num_ids)
-        indices[num_ids:] = torch.arange(self.shared_mem_size - num_ids)
+        torch.arange(num_ids, out=indices[:num_ids])
+        torch.arange(self.shared_mem_size - num_ids, out=indices[num_ids:])
         return indices
 
     def shuffle(self):
         """Shuffles the dataset."""
         # Only rank 0 does the actual shuffling. The other ranks wait for it.
         if self.rank == 0:
+            if self._device == torch.device('cpu'):
+                np.random.shuffle(self._indices[:self.num_indices].numpy())
+            else:
+                self._indices[:self.num_indices] = self._indices[
+                    torch.randperm(self.num_indices, device=self._device)]
             if not self.drop_last:
                 # pad extra
                 self._indices[self.num_indices:] = \
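
Two details in this hunk are worth spelling out, sketched below with made-up sizes (illustrative, not the DGL source): `torch.arange(..., out=...)` writes the ids straight into slices of the preallocated buffer, avoiding a temporary `arange` tensor per slice, and the rank-0 shuffle only falls back to `torch.randperm` when the indices live on a GPU, where NumPy cannot shuffle the memory in place.

```python
import numpy as np
import torch

num_ids, shared_mem_size = 7, 10          # illustrative sizes
indices = torch.empty(shared_mem_size, dtype=torch.int64)

# Fill slices of the preallocated buffer directly; no temporary arange tensors.
torch.arange(num_ids, out=indices[:num_ids])
torch.arange(shared_mem_size - num_ids, out=indices[num_ids:])

device = indices.device
if device == torch.device('cpu'):
    # In-place shuffle through a zero-copy NumPy view of the first num_ids slots.
    np.random.shuffle(indices[:num_ids].numpy())
else:
    # NumPy cannot touch GPU memory, so keep the randperm + gather path there.
    indices[:num_ids] = indices[torch.randperm(num_ids, device=device)]

print(indices)
```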
@@ -514,9 +517,10 @@ class CollateWrapper(object):
         self.device = device
 
     def __call__(self, items):
-        if self.use_uva or (self.g.device != torch.device('cpu')):
-            # Only copy the indices to the given device if in UVA mode or the graph is not on
-            # CPU.
+        graph_device = getattr(self.g, 'device', None)
+        if self.use_uva or (graph_device != torch.device('cpu')):
+            # Only copy the indices to the given device if in UVA mode or the graph
+            # is not on CPU.
             items = recursive_apply(items, lambda x: x.to(self.device))
         batch = self.sample_func(self.g, items)
         return recursive_apply(batch, remove_parent_storage_columns, self.g)
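
The `getattr` guard is the behavioral change here: if the wrapped `self.g` has no `device` attribute (for example when no graph is attached), the old `self.g.device` access would raise `AttributeError`, while `getattr(self.g, 'device', None)` falls back to `None` and the comparison still decides whether to move the indices. A small sketch of that pattern under that assumption (class and method names are illustrative, not the DGL classes):

```python
import torch

class CollateSketch:
    """Illustrative stand-in for a collate wrapper holding an optional graph."""
    def __init__(self, g, device, use_uva=False):
        self.g = g
        self.device = device
        self.use_uva = use_uva

    def needs_copy(self):
        # Tolerate objects without a .device attribute (e.g. g is None).
        graph_device = getattr(self.g, 'device', None)
        return self.use_uva or (graph_device != torch.device('cpu'))

print(CollateSketch(None, torch.device('cuda:0')).needs_copy())         # True, no AttributeError
print(CollateSketch(torch.zeros(1), torch.device('cpu')).needs_copy())  # False
```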