Unverified commit 580c7024, authored by Quan (Andy) Gan, committed by GitHub

[DataLoader] Allow batch_size=None for GraphDataLoader (#4483)

* overwrite default_collate_fn

* Update dataloader.py

* Update dataloader.py

* Update dataloader.py

* Update dataloader.py

* Update test_dataloader.py

* restore the test code that was reverted in #4956
parent e296c468
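For context, a rough usage sketch of what this change enables, based on the updated test further down (the MiniGCDataset here is only for illustration; the scalar-label behavior follows PyTorch's no-automatic-batching convention):

import dgl
from dgl.dataloading import GraphDataLoader

dataset = dgl.data.MiniGCDataset(8, 10, 20)   # 8 synthetic graphs with 10-20 nodes each

# batch_size=None disables automatic batching: no GraphCollator is attached and the
# loader yields one (graph, label) sample at a time instead of a batched graph.
loader = GraphDataLoader(dataset, batch_size=None, shuffle=True)
for graph, label in loader:
    assert isinstance(graph, dgl.DGLGraph)
    assert label.ndim == 0   # a single scalar label, not a batched label vector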
@@ -1124,17 +1124,15 @@ class GraphDataLoader(torch.utils.data.DataLoader):
             else:
                 dataloader_kwargs[k] = v
 
-        if collate_fn is None:
-            self.collate = GraphCollator(**collator_kwargs).collate
-        else:
-            self.collate = collate_fn
-
         self.use_ddp = use_ddp
         if use_ddp:
             self.dist_sampler = _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed)
             dataloader_kwargs['sampler'] = self.dist_sampler
 
-        super().__init__(dataset=dataset, collate_fn=self.collate, **dataloader_kwargs)
+        if collate_fn is None and kwargs.get('batch_size', 1) is not None:
+            collate_fn = GraphCollator(**collator_kwargs).collate
+
+        super().__init__(dataset=dataset, collate_fn=collate_fn, **dataloader_kwargs)
 
     def set_epoch(self, epoch):
         """Sets the epoch number for the underlying sampler which ensures all replicas
...
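The guard on kwargs.get('batch_size', 1) is the core of the change. A hypothetical standalone helper (not part of the PR; the name _pick_collate_fn is made up, and it assumes GraphCollator is importable from dgl.dataloading) expressing the same decision:

from dgl.dataloading import GraphCollator

def _pick_collate_fn(collate_fn, batch_size, collator_kwargs):
    # Hypothetical helper mirroring the guard above; not part of this PR.
    # PyTorch's DataLoader defaults batch_size to 1, hence the default of 1 in
    # kwargs.get('batch_size', 1): an omitted batch_size still means automatic
    # batching, so GraphCollator.collate is used to batch graphs and labels.
    # An explicit batch_size=None turns automatic batching off, so collate_fn is
    # left as None and torch.utils.data falls back to per-sample conversion.
    if collate_fn is None and batch_size is not None:
        return GraphCollator(**collator_kwargs).collate
    return collate_fn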
@@ -12,15 +12,21 @@ from test_utils import parametrize_idtype
 import pytest
 
 
-def test_graph_dataloader():
-    batch_size = 16
+@pytest.mark.parametrize('batch_size', [None, 16])
+def test_graph_dataloader(batch_size):
     num_batches = 2
-    minigc_dataset = dgl.data.MiniGCDataset(batch_size * num_batches, 10, 20)
+    num_samples = num_batches * (batch_size if batch_size is not None else 1)
+    minigc_dataset = dgl.data.MiniGCDataset(num_samples, 10, 20)
     data_loader = dgl.dataloading.GraphDataLoader(minigc_dataset, batch_size=batch_size, shuffle=True)
     assert isinstance(iter(data_loader), Iterator)
     for graph, label in data_loader:
         assert isinstance(graph, dgl.DGLGraph)
-        assert F.asnumpy(label).shape[0] == batch_size
+        if batch_size is not None:
+            assert F.asnumpy(label).shape[0] == batch_size
+        else:
+            # If batch size is None, the label element will be a single scalar following
+            # PyTorch's practice.
+            assert F.asnumpy(label).ndim == 0
 
 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 @pytest.mark.parametrize('num_workers', [0, 4])
...
@@ -1516,7 +1516,7 @@ def test_hgt(idtype, in_size, num_heads):
     sorted_y = m(sorted_g, sorted_x, sorted_ntype, sorted_etype, presorted=False)
     assert sorted_y.shape == (g.num_nodes(), head_size * num_heads)
     # mini-batch
-    train_idx = th.randint(0, 100, (10, ), dtype = idtype)
+    train_idx = th.randperm(100, dtype = idtype)[:10]
     sampler = dgl.dataloading.NeighborSampler([-1])
     train_loader = dgl.dataloading.DataLoader(g, train_idx.to(dev), sampler,
                                               batch_size=8, device=dev,
...
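A quick illustration of why the test now draws indices with randperm (assumption: the intent is to avoid duplicate seed node IDs, since th.randint samples with replacement):

import torch as th

a = th.randint(0, 100, (10,))   # sampling with replacement: values may repeat
b = th.randperm(100)[:10]       # prefix of a random permutation: values are distinct

assert b.unique().numel() == 10   # always holds
print(a.unique().numel())         # can be smaller than 10 when randint repeats a value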