"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5a6edac087915c7a92f3317067e82c1097b98307"
Unverified Commit bfef789e authored by Justus Schock, committed by GitHub
Browse files

[Dataloading] Make loader iters iterator (#2886)



* Make loader items iterator

* Update test_dataloader.py

* Update __init__.py

* Update test_dataloader.py
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent fbfcf1a8
...@@ -15,6 +15,9 @@ class _ScalarDataBatcherIter: ...@@ -15,6 +15,9 @@ class _ScalarDataBatcherIter:
self.index = 0 self.index = 0
self.drop_last = drop_last self.drop_last = drop_last
def __iter__(self):
return self
def __next__(self): def __next__(self):
num_items = self.dataset.shape[0] num_items = self.dataset.shape[0]
if self.index >= num_items: if self.index >= num_items:
...@@ -214,6 +217,9 @@ class _NodeDataLoaderIter: ...@@ -214,6 +217,9 @@ class _NodeDataLoaderIter:
self.node_dataloader = node_dataloader self.node_dataloader = node_dataloader
self.iter_ = iter(node_dataloader.dataloader) self.iter_ = iter(node_dataloader.dataloader)
def __iter__(self):
return self
def __next__(self): def __next__(self):
# input_nodes, output_nodes, blocks # input_nodes, output_nodes, blocks
result_ = next(self.iter_) result_ = next(self.iter_)
...@@ -228,6 +234,9 @@ class _EdgeDataLoaderIter: ...@@ -228,6 +234,9 @@ class _EdgeDataLoaderIter:
self.edge_dataloader = edge_dataloader self.edge_dataloader = edge_dataloader
self.iter_ = iter(edge_dataloader.dataloader) self.iter_ = iter(edge_dataloader.dataloader)
def __iter__(self):
return self
def __next__(self): def __next__(self):
result_ = next(self.iter_) result_ = next(self.iter_)
......
...@@ -3,6 +3,7 @@ import backend as F ...@@ -3,6 +3,7 @@ import backend as F
import unittest import unittest
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterator
from itertools import product from itertools import product
def _check_neighbor_sampling_dataloader(g, nids, dl, mode, collator): def _check_neighbor_sampling_dataloader(g, nids, dl, mode, collator):
...@@ -179,6 +180,7 @@ def test_neighbor_sampler_dataloader(): ...@@ -179,6 +180,7 @@ def test_neighbor_sampler_dataloader():
for _g, nid, collator, mode in zip(graphs, nids, collators, modes): for _g, nid, collator, mode in zip(graphs, nids, collators, modes):
dl = DataLoader( dl = DataLoader(
collator.dataset, collate_fn=collator.collate, batch_size=2, shuffle=True, drop_last=False) collator.dataset, collate_fn=collator.collate, batch_size=2, shuffle=True, drop_last=False)
assert isinstance(iter(dl), Iterator)
_check_neighbor_sampling_dataloader(_g, nid, dl, mode, collator) _check_neighbor_sampling_dataloader(_g, nid, dl, mode, collator)
def test_graph_dataloader(): def test_graph_dataloader():
...@@ -186,6 +188,7 @@ def test_graph_dataloader(): ...@@ -186,6 +188,7 @@ def test_graph_dataloader():
num_batches = 2 num_batches = 2
minigc_dataset = dgl.data.MiniGCDataset(batch_size * num_batches, 10, 20) minigc_dataset = dgl.data.MiniGCDataset(batch_size * num_batches, 10, 20)
data_loader = dgl.dataloading.GraphDataLoader(minigc_dataset, batch_size=batch_size, shuffle=True) data_loader = dgl.dataloading.GraphDataLoader(minigc_dataset, batch_size=batch_size, shuffle=True)
assert isinstance(iter(data_loader), Iterator)
for graph, label in data_loader: for graph, label in data_loader:
assert isinstance(graph, dgl.DGLGraph) assert isinstance(graph, dgl.DGLGraph)
assert F.asnumpy(label).shape[0] == batch_size assert F.asnumpy(label).shape[0] == batch_size
...@@ -226,6 +229,7 @@ def test_node_dataloader(): ...@@ -226,6 +229,7 @@ def test_node_dataloader():
dataloader = dgl.dataloading.NodeDataLoader( dataloader = dgl.dataloading.NodeDataLoader(
g2, {nty: g2.nodes(nty) for nty in g2.ntypes}, g2, {nty: g2.nodes(nty) for nty in g2.ntypes},
sampler, device=F.ctx(), batch_size=batch_size) sampler, device=F.ctx(), batch_size=batch_size)
assert isinstance(iter(dataloader), Iterator)
for input_nodes, output_nodes, blocks in dataloader: for input_nodes, output_nodes, blocks in dataloader:
_check_device(input_nodes) _check_device(input_nodes)
_check_device(output_nodes) _check_device(output_nodes)
...@@ -280,6 +284,8 @@ def test_edge_dataloader(): ...@@ -280,6 +284,8 @@ def test_edge_dataloader():
g2, {ety: g2.edges(form='eid', etype=ety) for ety in g2.canonical_etypes}, g2, {ety: g2.edges(form='eid', etype=ety) for ety in g2.canonical_etypes},
sampler, device=F.ctx(), negative_sampler=neg_sampler, sampler, device=F.ctx(), negative_sampler=neg_sampler,
batch_size=batch_size) batch_size=batch_size)
assert isinstance(iter(dataloader), Iterator)
for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader: for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
_check_device(input_nodes) _check_device(input_nodes)
_check_device(pos_pair_graph) _check_device(pos_pair_graph)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment