Unverified Commit 44638b93 authored by Quan (Andy) Gan, committed by GitHub

fix ddp dataloader in heterogeneous cases (#3801)

parent bb6cec23
@@ -36,7 +36,7 @@ class SAGE(nn.Module):
         # example is that the intermediate results can also benefit from prefetching.
         g.ndata['h'] = g.ndata['feat']
         sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
-        dataloader = dgl.dataloading.NodeDataLoader(
+        dataloader = dgl.dataloading.DataLoader(
             g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
             batch_size=1000, shuffle=False, drop_last=False, num_workers=num_workers,
             persistent_workers=(num_workers > 0))
@@ -77,7 +77,7 @@ def train(rank, world_size, graph, num_classes, split_idx):
         graph, train_idx, sampler,
         device='cuda', batch_size=1000, shuffle=True, drop_last=False,
         num_workers=0, use_ddp=True, use_uva=True)
-    valid_dataloader = dgl.dataloading.NodeDataLoader(
+    valid_dataloader = dgl.dataloading.DataLoader(
         graph, valid_idx, sampler, device='cuda', batch_size=1024, shuffle=True,
         drop_last=False, num_workers=0, use_uva=True)
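For context: both hunks above migrate example code from the removed NodeDataLoader alias to the unified dgl.dataloading.DataLoader. In the heterogeneous case this PR fixes, the seeds passed to that loader are a Mapping from node type to an ID tensor rather than a single tensor. A minimal sketch of that usage, with a toy graph and made-up node/edge type names:

    import torch
    import dgl

    # Toy heterogeneous graph; the 'user'/'item' type names are illustrative only.
    g = dgl.heterograph({
        ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),
        ('user', 'clicks', 'item'): (torch.tensor([0, 2]), torch.tensor([0, 1])),
    })
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)

    # For heterogeneous graphs the seeds are a dict {node type: ID tensor},
    # exactly the Mapping case that DDPTensorizedDataset mishandled below.
    train_idx = {'user': torch.arange(g.num_nodes('user')),
                 'item': torch.arange(g.num_nodes('item'))}
    dataloader = dgl.dataloading.DataLoader(
        g, train_idx, sampler, batch_size=2, shuffle=True, drop_last=False)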
@@ -169,8 +169,10 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
     def __init__(self, indices, batch_size, drop_last, ddp_seed):
         if isinstance(indices, Mapping):
             self._mapping_keys = list(indices.keys())
+            len_indices = sum(len(v) for v in indices.values())
         else:
             self._mapping_keys = None
+            len_indices = len(indices)
         self.rank = dist.get_rank()
         self.num_replicas = dist.get_world_size()
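The two inserted lines are the core of the fix: when indices is a Mapping, len(indices) is the number of node types, not the number of seed nodes. A worked sketch (the dict below is hypothetical):

    import torch

    indices = {'user': torch.arange(8), 'item': torch.arange(5)}  # hypothetical seeds

    len(indices)                           # 2: counts node types (the old bug)
    sum(len(v) for v in indices.values())  # 13: counts seed nodes (the fix)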
@@ -179,17 +181,17 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
         self.batch_size = batch_size
         self.drop_last = drop_last
-        if self.drop_last and len(indices) % self.num_replicas != 0:
-            self.num_samples = math.ceil((len(indices) - self.num_replicas) / self.num_replicas)
+        if self.drop_last and len_indices % self.num_replicas != 0:
+            self.num_samples = math.ceil((len_indices - self.num_replicas) / self.num_replicas)
         else:
-            self.num_samples = math.ceil(len(indices) / self.num_replicas)
+            self.num_samples = math.ceil(len_indices / self.num_replicas)
         self.total_size = self.num_samples * self.num_replicas
         # If drop_last is False, we create a shared memory array larger than the
         # number of indices since we will need to pad it after shuffling to make
         # it evenly divisible before every epoch. If drop_last is True, we create
         # an array with the same size as the indices so we can trim it later.
-        self.shared_mem_size = self.total_size if not self.drop_last else len(indices)
-        self.num_indices = len(indices)
+        self.shared_mem_size = self.total_size if not self.drop_last else len_indices
+        self.num_indices = len_indices
         if isinstance(indices, Mapping):
             self._device = next(iter(indices.values())).device
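To make the arithmetic above concrete, here is a small worked example of how num_samples and total_size come out for each drop_last setting (numbers chosen arbitrarily):

    import math

    len_indices, num_replicas = 13, 4

    # drop_last=False: round up, then pad 16 - 13 = 3 entries after shuffling
    # so every replica still gets ceil(13 / 4) = 4 samples.
    num_samples = math.ceil(len_indices / num_replicas)                   # 4
    total_size = num_samples * num_replicas                               # 16

    # drop_last=True (13 % 4 != 0): round down to 3 whole samples per replica
    # and trim 13 - 12 = 1 index from the shared-memory array.
    num_samples = math.ceil((len_indices - num_replicas) / num_replicas)  # 3
    total_size = num_samples * num_replicas                               # 12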