"git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "eb4131246e83be4110991167109b58a515f23162"
Commit 0f833526 authored by Myle Ott, committed by Facebook Github Bot

Add BufferedIterator (#419)

Summary:
This improves performance for datasets that load data lazily. Enabled by default since it shouldn't compromise performance for non-lazy datasets.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/419

Differential Revision: D13546585

Pulled By: myleott

fbshipit-source-id: f6152e2047291b0d68cd7506cd772b0caafe95be
parent 9ca82a0e
@@ -7,6 +7,8 @@
 import itertools
 import math
+import queue
+import threading
 import numpy as np
 import torch
@@ -67,14 +69,18 @@ class EpochBatchIterator(object):
         batch_sampler (~torch.utils.data.Sampler): an iterator over batches of
             indices
         seed (int, optional): seed for random number generator for
-            reproducibility. Default: ``1``
+            reproducibility. Default: 1
         num_shards (int, optional): shard the data iterator into N
-            shards. Default: ``1``
+            shards. Default: 1
         shard_id (int, optional): which shard of the data iterator to
-            return. Default: ``0``
+            return. Default: 0
+        buffer_size (int, optional): number of batches to buffer. Default: 5
     """

-    def __init__(self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0):
+    def __init__(
+        self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
+        buffer_size=5,
+    ):
         assert isinstance(dataset, torch.utils.data.Dataset)
         self.dataset = dataset
         self.collate_fn = collate_fn
@@ -82,6 +88,7 @@ class EpochBatchIterator(object):
         self.seed = seed
         self.num_shards = num_shards
         self.shard_id = shard_id
+        self.buffer_size = buffer_size

         self.epoch = 0
         self._cur_epoch_itr = None
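As a rough illustration of how the new buffer_size argument is threaded through the constructor above, here is a minimal sketch. The toy dataset, collate function, and pre-computed index batches are invented stand-ins, and it assumes this file is fairseq/data/iterators.py and that the rest of EpochBatchIterator (e.g. next_epoch_itr) is unchanged from this revision:

import torch
from fairseq.data.iterators import EpochBatchIterator  # assumed import path

class ToyDataset(torch.utils.data.Dataset):
    # stand-in for a (possibly lazily loading) dataset
    def __len__(self):
        return 8
    def __getitem__(self, index):
        return torch.tensor([index])

def collate(samples):
    # merge a list of samples into a mini-batch
    return torch.stack(samples)

batches = [[0, 1], [2, 3], [4, 5], [6, 7]]  # pre-computed batches of indices

epoch_itr = EpochBatchIterator(
    dataset=ToyDataset(),
    collate_fn=collate,
    batch_sampler=batches,
    seed=1,
    num_shards=1,
    shard_id=0,
    buffer_size=5,  # new in this commit: number of batches to prefetch in the background
)
for batch in epoch_itr.next_epoch_itr(shuffle=False):
    print(batch)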
@@ -172,13 +179,50 @@
             batches = self.frozen_batches
         batches = ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
-        return CountingIterator(torch.utils.data.DataLoader(
-            self.dataset,
-            collate_fn=self.collate_fn,
-            batch_sampler=batches,
+        return CountingIterator(BufferedIterator(
+            torch.utils.data.DataLoader(
+                self.dataset,
+                collate_fn=self.collate_fn,
+                batch_sampler=batches,
+            ),
+            buffer_size=self.buffer_size,
         ))
+
+
+class BufferedIterator(object):
+    """Wrapper around an iterable that prefetches items into a buffer.
+
+    Args:
+        iterable (iterable): iterable to wrap
+        buffer_size (int): number of items to prefetch and buffer
+    """
+
+    def __init__(self, iterable, buffer_size):
+        self.iterable = iterable
+        self.q = queue.Queue(maxsize=buffer_size)
+        self.thread = threading.Thread(target=self._load_q, daemon=True)
+        self.thread.start()
+
+    def __len__(self):
+        return len(self.iterable)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        x = self.q.get()
+        if x is None:
+            self.thread.join()
+            raise StopIteration
+        return x[0]
+
+    def _load_q(self):
+        for x in self.iterable:
+            self.q.put([x])  # wrap in a list so that it is never None
+        self.q.put(None)


 class GroupedIterator(object):
     """Wrapper around an iterable that returns groups (chunks) of items.
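To see the effect described in the summary, here is a small sketch that wraps a deliberately slow producer in the BufferedIterator added above. slow_source and the sleep times are invented for illustration, and the import path assumes this file is fairseq/data/iterators.py:

import time
from fairseq.data.iterators import BufferedIterator  # assumed import path

def slow_source(n=10, delay=0.05):
    # stand-in for a dataset that loads each item lazily (e.g. from disk)
    for i in range(n):
        time.sleep(delay)
        yield i

start = time.time()
for item in BufferedIterator(slow_source(), buffer_size=5):
    time.sleep(0.05)  # simulate the consumer, e.g. a training step
print("buffered: {:.2f}s".format(time.time() - start))

start = time.time()
for item in slow_source():
    time.sleep(0.05)
print("unbuffered: {:.2f}s".format(time.time() - start))

Because the background thread keeps loading while the consumer works, the buffered loop should take roughly half as long as the unbuffered one. With a non-lazy dataset the producer is nearly free, so the two loops take about the same time, which is why the buffering can be enabled by default.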