"src/vscode:/vscode.git/clone" did not exist on "a94977b8b32b94ccd00d2f8f812aadb46764baba"
Unverified commit bfef789e, authored by Justus Schock, committed by GitHub
Browse files

[Dataloading] Make loader iters iterator (#2886)



* Make loader items iterator

* Update test_dataloader.py

* Update __init__.py

* Update test_dataloader.py
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent fbfcf1a8
......@@ -15,6 +15,9 @@ class _ScalarDataBatcherIter:
self.index = 0
self.drop_last = drop_last
def __iter__(self):
return self
def __next__(self):
num_items = self.dataset.shape[0]
if self.index >= num_items:
......@@ -214,6 +217,9 @@ class _NodeDataLoaderIter:
self.node_dataloader = node_dataloader
self.iter_ = iter(node_dataloader.dataloader)
def __iter__(self):
return self
def __next__(self):
# input_nodes, output_nodes, blocks
result_ = next(self.iter_)
......@@ -228,6 +234,9 @@ class _EdgeDataLoaderIter:
self.edge_dataloader = edge_dataloader
self.iter_ = iter(edge_dataloader.dataloader)
def __iter__(self):
return self
def __next__(self):
result_ = next(self.iter_)
......
......@@ -3,6 +3,7 @@ import backend as F
import unittest
from torch.utils.data import DataLoader
from collections import defaultdict
from collections.abc import Iterator
from itertools import product
def _check_neighbor_sampling_dataloader(g, nids, dl, mode, collator):
......@@ -179,6 +180,7 @@ def test_neighbor_sampler_dataloader():
for _g, nid, collator, mode in zip(graphs, nids, collators, modes):
dl = DataLoader(
collator.dataset, collate_fn=collator.collate, batch_size=2, shuffle=True, drop_last=False)
assert isinstance(iter(dl), Iterator)
_check_neighbor_sampling_dataloader(_g, nid, dl, mode, collator)
def test_graph_dataloader():
......@@ -186,6 +188,7 @@ def test_graph_dataloader():
num_batches = 2
minigc_dataset = dgl.data.MiniGCDataset(batch_size * num_batches, 10, 20)
data_loader = dgl.dataloading.GraphDataLoader(minigc_dataset, batch_size=batch_size, shuffle=True)
assert isinstance(iter(data_loader), Iterator)
for graph, label in data_loader:
assert isinstance(graph, dgl.DGLGraph)
assert F.asnumpy(label).shape[0] == batch_size
......@@ -226,6 +229,7 @@ def test_node_dataloader():
dataloader = dgl.dataloading.NodeDataLoader(
g2, {nty: g2.nodes(nty) for nty in g2.ntypes},
sampler, device=F.ctx(), batch_size=batch_size)
assert isinstance(iter(dataloader), Iterator)
for input_nodes, output_nodes, blocks in dataloader:
_check_device(input_nodes)
_check_device(output_nodes)
......@@ -280,6 +284,8 @@ def test_edge_dataloader():
g2, {ety: g2.edges(form='eid', etype=ety) for ety in g2.canonical_etypes},
sampler, device=F.ctx(), negative_sampler=neg_sampler,
batch_size=batch_size)
assert isinstance(iter(dataloader), Iterator)
for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
_check_device(input_nodes)
_check_device(pos_pair_graph)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment