"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "025e43210f724ea3a5f308463166e61b4cd97399"
Unverified commit 794ec4a4, authored by maqy and committed by GitHub

[BugFix] fix unstable sort when using dataloader with HeteroGraph (#4147)



* fix unstable sort

* add torch version check

* reformat

* split too long comments

* Update dataloader.py
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 31e4a89b
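The core of this fix is switching from an unstable to a stable sort of the per-seed-node type IDs. Below is a minimal illustrative sketch (not part of the commit) of why that matters for heterograph seed nodes when shuffle=False:

# Illustrative sketch: why an unstable sort can scramble output_nodes
# for a heterograph even when shuffle=False.
import torch

type_ids = torch.tensor([1, 0, 1, 0, 1])       # node-type id of each seed node
indices = torch.tensor([10, 20, 11, 21, 12])   # per-type node ids, original order

# torch.argsort gives no stability guarantee: seeds sharing a type id may be
# reordered arbitrarily, so the output order can differ from the input order.
unstable_idx = torch.argsort(type_ids)

# A stable sort (PyTorch >= 1.10) keeps ties in their original relative order,
# so seeds of each type stay in the order the user supplied.
_, stable_idx = torch.sort(type_ids, stable=True)
print(indices[stable_idx])   # tensor([20, 21, 10, 11, 12])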
@@ -38,12 +38,13 @@ atexit.register(_set_python_exit_flag)
 prefetcher_timeout = int(os.environ.get('DGL_PREFETCHER_TIMEOUT', '30'))
 class _TensorizedDatasetIter(object):
-    def __init__(self, dataset, batch_size, drop_last, mapping_keys):
+    def __init__(self, dataset, batch_size, drop_last, mapping_keys, shuffle):
         self.dataset = dataset
         self.batch_size = batch_size
         self.drop_last = drop_last
         self.mapping_keys = mapping_keys
         self.index = 0
+        self.shuffle = shuffle
     # For PyTorch Lightning compatibility
     def __iter__(self):
@@ -72,7 +73,16 @@ class _TensorizedDatasetIter(object):
         # convert the type-ID pairs to dictionary
         type_ids = batch[:, 0]
         indices = batch[:, 1]
-        type_ids_sortidx = torch.argsort(type_ids)
+        if PYTORCH_VER >= LooseVersion("1.10.0"):
+            _, type_ids_sortidx = torch.sort(type_ids, stable=True)
+        else:
+            if not self.shuffle:
+                dgl_warning(
+                    'The current output_nodes are out of order even though shuffle '
+                    'is set to False in the DataLoader, because this version of '
+                    'torch does not support stable sort. '
+                    'Please update torch to 1.10.0 or higher to fix it.')
+            type_ids_sortidx = torch.argsort(type_ids)
         type_ids = type_ids[type_ids_sortidx]
         indices = indices[type_ids_sortidx]
         type_id_uniq, type_id_count = torch.unique_consecutive(type_ids, return_counts=True)
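For context, once the pairs are stably sorted, the iterator groups them back into a per-type dictionary via torch.unique_consecutive. A rough sketch of that step follows; the node-type names 'user' and 'item' are hypothetical stand-ins for self.mapping_keys:

# Sketch of the grouping step that follows the sort.
import torch

mapping_keys = ['user', 'item']
type_ids = torch.tensor([0, 0, 1, 1, 1])       # stably sorted type ids
indices = torch.tensor([20, 21, 10, 11, 12])   # ids, intra-type order preserved

type_id_uniq, type_id_count = torch.unique_consecutive(type_ids, return_counts=True)
id_splits = torch.split(indices, type_id_count.tolist())
batch = {mapping_keys[t]: ids for t, ids in zip(type_id_uniq.tolist(), id_splits)}
# {'user': tensor([20, 21]), 'item': tensor([10, 11, 12])}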
@@ -122,7 +132,7 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
     """Custom Dataset wrapper that returns a minibatch as tensors or dicts of tensors.
     When the dataset is on the GPU, this significantly reduces the overhead.
     """
-    def __init__(self, indices, batch_size, drop_last):
+    def __init__(self, indices, batch_size, drop_last, shuffle):
         if isinstance(indices, Mapping):
             self._mapping_keys = list(indices.keys())
             self._device = next(iter(indices.values())).device
@@ -138,6 +148,7 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
         self._indices = torch.arange(self._id_tensor.shape[0], dtype=torch.int64).share_memory_()
         self.batch_size = batch_size
         self.drop_last = drop_last
+        self._shuffle = shuffle
     def shuffle(self):
         """Shuffle the dataset."""
@@ -147,7 +158,7 @@ class TensorizedDataset(torch.utils.data.IterableDataset):
         indices = _divide_by_worker(self._indices, self.batch_size, self.drop_last)
         id_tensor = self._id_tensor[indices.to(self._device)]
         return _TensorizedDatasetIter(
-            id_tensor, self.batch_size, self.drop_last, self._mapping_keys)
+            id_tensor, self.batch_size, self.drop_last, self._mapping_keys, self._shuffle)
     def __len__(self):
         num_samples = self._id_tensor.shape[0]
@@ -160,7 +171,7 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
     This class additionally saves the index tensor in shared memory and therefore
     avoids duplicating the same index tensor during shuffling.
     """
-    def __init__(self, indices, batch_size, drop_last, ddp_seed):
+    def __init__(self, indices, batch_size, drop_last, ddp_seed, shuffle):
         if isinstance(indices, Mapping):
             self._mapping_keys = list(indices.keys())
             len_indices = sum(len(v) for v in indices.values())
@@ -174,6 +185,7 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
         self.epoch = 0
         self.batch_size = batch_size
         self.drop_last = drop_last
+        self._shuffle = shuffle
         if self.drop_last and len_indices % self.num_replicas != 0:
             self.num_samples = math.ceil((len_indices - self.num_replicas) / self.num_replicas)
@@ -228,7 +240,7 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
         indices = _divide_by_worker(self._indices[start:end], self.batch_size, self.drop_last)
         id_tensor = self._id_tensor[indices.to(self._device)]
         return _TensorizedDatasetIter(
-            id_tensor, self.batch_size, self.drop_last, self._mapping_keys)
+            id_tensor, self.batch_size, self.drop_last, self._mapping_keys, self._shuffle)
     def __len__(self):
         return (self.num_samples + (0 if self.drop_last else (self.batch_size - 1))) // \
@@ -539,15 +551,16 @@ class WorkerInitWrapper(object):
         self.func(worker_id)
-def create_tensorized_dataset(indices, batch_size, drop_last, use_ddp, ddp_seed):
+def create_tensorized_dataset(indices, batch_size, drop_last, use_ddp, ddp_seed,
+                              shuffle):
     """Converts a given indices tensor to a TensorizedDataset, an IterableDataset
     that returns views of the original tensor, to reduce overhead from having
     a list of scalar tensors in default PyTorch DataLoader implementation.
     """
     if use_ddp:
-        return DDPTensorizedDataset(indices, batch_size, drop_last, ddp_seed)
+        return DDPTensorizedDataset(indices, batch_size, drop_last, ddp_seed, shuffle)
     else:
-        return TensorizedDataset(indices, batch_size, drop_last)
+        return TensorizedDataset(indices, batch_size, drop_last, shuffle)
 def _get_device(device):
@@ -814,7 +827,7 @@ class DataLoader(torch.utils.data.DataLoader):
                 isinstance(indices, Mapping) and
                 all(torch.is_tensor(v) for v in indices.values()))):
             self.dataset = create_tensorized_dataset(
-                indices, batch_size, drop_last, use_ddp, ddp_seed)
+                indices, batch_size, drop_last, use_ddp, ddp_seed, shuffle)
         else:
             self.dataset = indices
...
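As an end-to-end illustration of the behavior this patch guarantees, here is a hypothetical usage sketch: with shuffle=False and PyTorch >= 1.10, the output_nodes of a heterograph DataLoader come back in the order the seed nodes were supplied. The toy heterograph and sampler below are assumptions for illustration, not part of the commit:

# Hypothetical usage sketch; the toy graph and sampler are made up.
import torch
import dgl

g = dgl.heterograph({
    ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),
    ('user', 'plays', 'game'): (torch.tensor([0, 2]), torch.tensor([0, 1])),
})
sampler = dgl.dataloading.NeighborSampler([2])
train_nids = {'user': torch.arange(3), 'game': torch.arange(2)}

dataloader = dgl.dataloading.DataLoader(
    g, train_nids, sampler, batch_size=4, shuffle=False, drop_last=False)
for input_nodes, output_nodes, blocks in dataloader:
    # With the stable sort, per-type seeds in output_nodes keep the order
    # given in train_nids instead of being scrambled within each type.
    print(output_nodes)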