Unverified Commit 6db2ad1c authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

Background generator (#323)

* BackgroundGenerator
* renaming disk cache.
parent 99c52600
client_id path sentence up_votes down_votes age gender accent
00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 common_voice_tt_00000000.mp3 test. 1 0 thirties female
00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001 common_voice_tt_00000000.mp3 test. 1 0 thirties female
......@@ -3,7 +3,7 @@ import unittest
from torchaudio.datasets.commonvoice import COMMONVOICE
from torchaudio.datasets.librispeech import LIBRISPEECH
from torchaudio.datasets.utils import DiskCache
from torchaudio.datasets.utils import diskcache_iterator, bg_iterator
from torchaudio.datasets.vctk import VCTK
from torchaudio.datasets.yesno import YESNO
......@@ -34,12 +34,19 @@ class TestDatasets(unittest.TestCase):
def test_commonvoice_diskcache(self):
path = os.path.join(self.path, "commonvoice")
data = COMMONVOICE(path, "train.tsv", "tatar")
data = DiskCache(data)
data = diskcache_iterator(data)
# Save
data[0]
# Load
data[0]
def test_commonvoice_bg(self):
path = os.path.join(self.path, "commonvoice")
data = COMMONVOICE(path, "train.tsv", "tatar")
data = bg_iterator(data, 5)
for d in data:
pass
if __name__ == "__main__":
unittest.main()
from .commonvoice import COMMONVOICE
from .librispeech import LIBRISPEECH
from .utils import bg_iterator, diskcache_iterator
from .vctk import VCTK
from .yesno import YESNO
from .utils import DiskCache
__all__ = ("COMMONVOICE", "LIBRISPEECH", "VCTK", "YESNO", "DiskCache")
__all__ = (
"COMMONVOICE",
"LIBRISPEECH",
"VCTK",
"YESNO",
"diskcache_iterator",
"bg_iterator",
)
......@@ -6,7 +6,9 @@ import logging
import os
import sys
import tarfile
import threading
import zipfile
from queue import Queue
import six
import torch
......@@ -192,7 +194,7 @@ def walk_files(root, suffix, prefix=False, remove_suffix=False):
yield f
class DiskCache(Dataset):
class _DiskCache(Dataset):
"""
Wrap a dataset so that, whenever a new item is returned, it is saved to disk.
"""
......@@ -221,3 +223,45 @@ class DiskCache(Dataset):
def __len__(self):
return len(self.dataset)
def diskcache_iterator(dataset, location=".cached"):
return _DiskCache(dataset, location)
class _ThreadedIterator(threading.Thread):
"""
Prefetch the next queue_length items from iterator in a background thread.
Example:
>> for i in bg_iterator(range(10)):
>> print(i)
"""
class _End:
pass
def __init__(self, generator, maxsize):
threading.Thread.__init__(self)
self.queue = Queue(maxsize)
self.generator = generator
self.daemon = True
self.start()
def run(self):
for item in self.generator:
self.queue.put(item)
self.queue.put(self._End)
def __iter__(self):
return self
def __next__(self):
next_item = self.queue.get()
if next_item == self._End:
raise StopIteration
return next_item
def bg_iterator(iterable, maxsize):
return _ThreadedIterator(iterable, maxsize=maxsize)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment