Unverified Commit 794ec4a4 authored by maqy, committed by GitHub

[BugFix] fix unstable sort when using dataloader with HeteroGraph (#4147)



* fix unstable sort

* add torch version check

* reformat

* split too long comments

* Update dataloader.py
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 31e4a89b
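
Background: `torch.argsort` gives no ordering guarantee for equal keys, so when the heterograph iterator sorts a batch's `(type_id, index)` pairs, node IDs within the same node type could come back permuted even with `shuffle=False`. A minimal sketch (tensors are made up, not from the patch) of the difference, using `torch.sort(..., stable=True)` available from PyTorch 1.10:

```python
import torch

type_ids = torch.tensor([1, 0, 1, 0, 1, 0])        # per-row node-type IDs
indices = torch.tensor([10, 20, 11, 21, 12, 22])   # node IDs in insertion order

# Unstable: ties (equal type IDs) may be permuted arbitrarily,
# depending on backend and tensor size.
unstable_idx = torch.argsort(type_ids)

# Stable: rows with equal type IDs keep their original relative order,
# so each type's node IDs stay in insertion order deterministically.
_, stable_idx = torch.sort(type_ids, stable=True)

print(indices[unstable_idx])  # intra-type order not guaranteed
print(indices[stable_idx])    # tensor([20, 21, 22, 10, 11, 12])
```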
@@ -38,12 +38,13 @@ atexit.register(_set_python_exit_flag)
 prefetcher_timeout = int(os.environ.get('DGL_PREFETCHER_TIMEOUT', '30'))

 class _TensorizedDatasetIter(object):
-    def __init__(self, dataset, batch_size, drop_last, mapping_keys):
+    def __init__(self, dataset, batch_size, drop_last, mapping_keys, shuffle):
         self.dataset = dataset
         self.batch_size = batch_size
         self.drop_last = drop_last
         self.mapping_keys = mapping_keys
         self.index = 0
+        self.shuffle = shuffle

     # For PyTorch Lightning compatibility
     def __iter__(self):
@@ -72,7 +73,16 @@ class _TensorizedDatasetIter(object):
         # convert the type-ID pairs to dictionary
         type_ids = batch[:, 0]
         indices = batch[:, 1]
-        type_ids_sortidx = torch.argsort(type_ids)
+        if PYTORCH_VER >= LooseVersion("1.10.0"):
+            _, type_ids_sortidx = torch.sort(type_ids, stable=True)
+        else:
+            if not self.shuffle:
+                dgl_warning(
+                    'The output_nodes are out of order even though shuffle is '
+                    'set to False in the DataLoader, because this version of '
+                    'PyTorch does not support stable sort. '
+                    'Please upgrade PyTorch to 1.10.0 or higher to fix it.')
+            type_ids_sortidx = torch.argsort(type_ids)
         type_ids = type_ids[type_ids_sortidx]
         indices = indices[type_ids_sortidx]
         type_id_uniq, type_id_count = torch.unique_consecutive(type_ids, return_counts=True)
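
For context, a sketch of what happens just below this hunk: `torch.unique_consecutive` merges only *adjacent* duplicates, which is why `type_ids` must be sorted first; the counts then slice `indices` into one tensor per node type. The names and sample values here are illustrative, not from the patch:

```python
import torch

mapping_keys = ['author', 'paper']                 # hypothetical node types
type_ids = torch.tensor([0, 0, 0, 1, 1, 1])        # stably sorted type IDs
indices = torch.tensor([20, 21, 22, 10, 11, 12])   # node IDs, same permutation

uniq, counts = torch.unique_consecutive(type_ids, return_counts=True)
offsets = [0] + counts.cumsum(0).tolist()          # block boundaries per type
batch = {
    mapping_keys[t]: indices[offsets[i]:offsets[i + 1]]
    for i, t in enumerate(uniq.tolist())
}
# {'author': tensor([20, 21, 22]), 'paper': tensor([10, 11, 12])}
```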
@@ -122,7 +132,7 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
     """Custom Dataset wrapper that returns a minibatch as tensors or dicts of tensors.
     When the dataset is on the GPU, this significantly reduces the overhead.
     """
-    def __init__(self, indices, batch_size, drop_last):
+    def __init__(self, indices, batch_size, drop_last, shuffle):
         if isinstance(indices, Mapping):
             self._mapping_keys = list(indices.keys())
             self._device = next(iter(indices.values())).device
@@ -138,6 +148,7 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
             self._indices = torch.arange(self._id_tensor.shape[0], dtype=torch.int64).share_memory_()
         self.batch_size = batch_size
         self.drop_last = drop_last
+        self._shuffle = shuffle

     def shuffle(self):
         """Shuffle the dataset."""
@@ -147,7 +158,7 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
         indices = _divide_by_worker(self._indices, self.batch_size, self.drop_last)
         id_tensor = self._id_tensor[indices.to(self._device)]
         return _TensorizedDatasetIter(
-            id_tensor, self.batch_size, self.drop_last, self._mapping_keys)
+            id_tensor, self.batch_size, self.drop_last, self._mapping_keys, self._shuffle)

     def __len__(self):
         num_samples = self._id_tensor.shape[0]
@@ -160,7 +171,7 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
    This class additionally saves the index tensor in shared memory and therefore
    avoids duplicating the same index tensor during shuffling.
    """
-    def __init__(self, indices, batch_size, drop_last, ddp_seed):
+    def __init__(self, indices, batch_size, drop_last, ddp_seed, shuffle):
        if isinstance(indices, Mapping):
            self._mapping_keys = list(indices.keys())
            len_indices = sum(len(v) for v in indices.values())
@@ -174,6 +185,7 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
         self.epoch = 0
         self.batch_size = batch_size
         self.drop_last = drop_last
+        self._shuffle = shuffle

         if self.drop_last and len_indices % self.num_replicas != 0:
             self.num_samples = math.ceil((len_indices - self.num_replicas) / self.num_replicas)
@@ -228,7 +240,7 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
         indices = _divide_by_worker(self._indices[start:end], self.batch_size, self.drop_last)
         id_tensor = self._id_tensor[indices.to(self._device)]
         return _TensorizedDatasetIter(
-            id_tensor, self.batch_size, self.drop_last, self._mapping_keys)
+            id_tensor, self.batch_size, self.drop_last, self._mapping_keys, self._shuffle)

     def __len__(self):
         return (self.num_samples + (0 if self.drop_last else (self.batch_size - 1))) // \
@@ -539,15 +551,16 @@ class WorkerInitWrapper(object):
         self.func(worker_id)


-def create_tensorized_dataset(indices, batch_size, drop_last, use_ddp, ddp_seed):
+def create_tensorized_dataset(indices, batch_size, drop_last, use_ddp, ddp_seed,
+                              shuffle):
     """Converts a given indices tensor to a TensorizedDataset, an IterableDataset
     that returns views of the original tensor, to reduce overhead from having
     a list of scalar tensors in default PyTorch DataLoader implementation.
     """
     if use_ddp:
-        return DDPTensorizedDataset(indices, batch_size, drop_last, ddp_seed)
+        return DDPTensorizedDataset(indices, batch_size, drop_last, ddp_seed, shuffle)
     else:
-        return TensorizedDataset(indices, batch_size, drop_last)
+        return TensorizedDataset(indices, batch_size, drop_last, shuffle)


 def _get_device(device):
@@ -814,7 +827,7 @@ class DataLoader(torch.utils.data.DataLoader):
                 isinstance(indices, Mapping) and
                 all(torch.is_tensor(v) for v in indices.values()))):
             self.dataset = create_tensorized_dataset(
-                indices, batch_size, drop_last, use_ddp, ddp_seed)
+                indices, batch_size, drop_last, use_ddp, ddp_seed, shuffle)
         else:
             self.dataset = indices
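
End-to-end, the user-visible effect is that `output_nodes` of each minibatch are deterministic per node type when `shuffle=False`. A hedged usage sketch (the toy graph and seed nodes are made up; assumes a DGL version where `dgl.dataloading.DataLoader` accepts a dict of seed nodes):

```python
import dgl
import torch

g = dgl.heterograph({
    ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),
    ('user', 'plays', 'game'): (torch.tensor([0, 1, 2]), torch.tensor([0, 0, 1])),
})
train_nids = {'user': torch.arange(3), 'game': torch.arange(2)}

sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
loader = dgl.dataloading.DataLoader(
    g, train_nids, sampler,
    batch_size=4, shuffle=False, drop_last=False)

for input_nodes, output_nodes, blocks in loader:
    # On PyTorch >= 1.10 the per-type IDs below keep their original order;
    # older PyTorch falls back to argsort and emits a dgl_warning.
    print(output_nodes)
```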