Unverified Commit 4c147814 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Optimization] Memory consumption optimization on index shuffling in dataloader (#3980)

* fix

* revert

* Update dataloader.py
parent 65e6b04d
...@@ -11,6 +11,7 @@ import atexit ...@@ -11,6 +11,7 @@ import atexit
import os import os
import psutil import psutil
import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
...@@ -134,15 +135,13 @@ class TensorizedDataset(torch.utils.data.IterableDataset): ...@@ -134,15 +135,13 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
# Use a shared memory array to permute indices for shuffling. This is to make sure that # Use a shared memory array to permute indices for shuffling. This is to make sure that
# the worker processes can see it when persistent_workers=True, where self._indices # the worker processes can see it when persistent_workers=True, where self._indices
# would not be duplicated every epoch. # would not be duplicated every epoch.
self._indices = torch.empty(self._id_tensor.shape[0], dtype=torch.int64).share_memory_() self._indices = torch.arange(self._id_tensor.shape[0], dtype=torch.int64).share_memory_()
self._indices[:] = torch.arange(self._id_tensor.shape[0])
self.batch_size = batch_size self.batch_size = batch_size
self.drop_last = drop_last self.drop_last = drop_last
def shuffle(self): def shuffle(self):
"""Shuffle the dataset.""" """Shuffle the dataset."""
# TODO: may need an in-place shuffle kernel np.random.shuffle(self._indices.numpy())
self._indices[:] = self._indices[torch.randperm(self._indices.shape[0])]
def __iter__(self): def __iter__(self):
indices = _divide_by_worker(self._indices, self.batch_size, self.drop_last) indices = _divide_by_worker(self._indices, self.batch_size, self.drop_last)
...@@ -203,16 +202,20 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset): ...@@ -203,16 +202,20 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
def _create_shared_indices(self): def _create_shared_indices(self):
indices = torch.empty(self.shared_mem_size, dtype=torch.int64) indices = torch.empty(self.shared_mem_size, dtype=torch.int64)
num_ids = self._id_tensor.shape[0] num_ids = self._id_tensor.shape[0]
indices[:num_ids] = torch.arange(num_ids) torch.arange(num_ids, out=indices[:num_ids])
indices[num_ids:] = torch.arange(self.shared_mem_size - num_ids) torch.arange(self.shared_mem_size - num_ids, out=indices[num_ids:])
return indices return indices
def shuffle(self): def shuffle(self):
"""Shuffles the dataset.""" """Shuffles the dataset."""
# Only rank 0 does the actual shuffling. The other ranks wait for it. # Only rank 0 does the actual shuffling. The other ranks wait for it.
if self.rank == 0: if self.rank == 0:
if self._device == torch.device('cpu'):
np.random.shuffle(self._indices[:self.num_indices].numpy())
else:
self._indices[:self.num_indices] = self._indices[ self._indices[:self.num_indices] = self._indices[
torch.randperm(self.num_indices, device=self._device)] torch.randperm(self.num_indices, device=self._device)]
if not self.drop_last: if not self.drop_last:
# pad extra # pad extra
self._indices[self.num_indices:] = \ self._indices[self.num_indices:] = \
...@@ -514,9 +517,10 @@ class CollateWrapper(object): ...@@ -514,9 +517,10 @@ class CollateWrapper(object):
self.device = device self.device = device
def __call__(self, items): def __call__(self, items):
if self.use_uva or (self.g.device != torch.device('cpu')): graph_device = getattr(self.g, 'device', None)
# Only copy the indices to the given device if in UVA mode or the graph is not on if self.use_uva or (graph_device != torch.device('cpu')):
# CPU. # Only copy the indices to the given device if in UVA mode or the graph
# is not on CPU.
items = recursive_apply(items, lambda x: x.to(self.device)) items = recursive_apply(items, lambda x: x.to(self.device))
batch = self.sample_func(self.g, items) batch = self.sample_func(self.g, items)
return recursive_apply(batch, remove_parent_storage_columns, self.g) return recursive_apply(batch, remove_parent_storage_columns, self.g)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment