Unverified Commit 6db2ad1c authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

Background generator (#323)

* BackgroundGenerator
* renaming disk cache.
parent 99c52600
client_id path sentence up_votes down_votes age gender accent client_id path sentence up_votes down_votes age gender accent
00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 common_voice_tt_00000000.mp3 test. 1 0 thirties female 00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 common_voice_tt_00000000.mp3 test. 1 0 thirties female
00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001 common_voice_tt_00000000.mp3 test. 1 0 thirties female
...@@ -3,7 +3,7 @@ import unittest ...@@ -3,7 +3,7 @@ import unittest
from torchaudio.datasets.commonvoice import COMMONVOICE from torchaudio.datasets.commonvoice import COMMONVOICE
from torchaudio.datasets.librispeech import LIBRISPEECH from torchaudio.datasets.librispeech import LIBRISPEECH
from torchaudio.datasets.utils import DiskCache from torchaudio.datasets.utils import diskcache_iterator, bg_iterator
from torchaudio.datasets.vctk import VCTK from torchaudio.datasets.vctk import VCTK
from torchaudio.datasets.yesno import YESNO from torchaudio.datasets.yesno import YESNO
...@@ -34,12 +34,19 @@ class TestDatasets(unittest.TestCase): ...@@ -34,12 +34,19 @@ class TestDatasets(unittest.TestCase):
def test_commonvoice_diskcache(self): def test_commonvoice_diskcache(self):
path = os.path.join(self.path, "commonvoice") path = os.path.join(self.path, "commonvoice")
data = COMMONVOICE(path, "train.tsv", "tatar") data = COMMONVOICE(path, "train.tsv", "tatar")
data = DiskCache(data) data = diskcache_iterator(data)
# Save # Save
data[0] data[0]
# Load # Load
data[0] data[0]
def test_commonvoice_bg(self):
    # Smoke test: drain the dataset through a background iterator using a
    # prefetch queue of size 5. The test passes if iteration completes.
    dataset_dir = os.path.join(self.path, "commonvoice")
    dataset = COMMONVOICE(dataset_dir, "train.tsv", "tatar")
    for _ in bg_iterator(dataset, 5):
        pass
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
from .commonvoice import COMMONVOICE from .commonvoice import COMMONVOICE
from .librispeech import LIBRISPEECH from .librispeech import LIBRISPEECH
from .utils import bg_iterator, diskcache_iterator
from .vctk import VCTK from .vctk import VCTK
from .yesno import YESNO from .yesno import YESNO
from .utils import DiskCache
__all__ = ("COMMONVOICE", "LIBRISPEECH", "VCTK", "YESNO", "DiskCache") __all__ = (
"COMMONVOICE",
"LIBRISPEECH",
"VCTK",
"YESNO",
"diskcache_iterator",
"bg_iterator",
)
...@@ -6,7 +6,9 @@ import logging ...@@ -6,7 +6,9 @@ import logging
import os import os
import sys import sys
import tarfile import tarfile
import threading
import zipfile import zipfile
from queue import Queue
import six import six
import torch import torch
...@@ -192,7 +194,7 @@ def walk_files(root, suffix, prefix=False, remove_suffix=False): ...@@ -192,7 +194,7 @@ def walk_files(root, suffix, prefix=False, remove_suffix=False):
yield f yield f
class DiskCache(Dataset): class _DiskCache(Dataset):
""" """
Wrap a dataset so that, whenever a new item is returned, it is saved to disk. Wrap a dataset so that, whenever a new item is returned, it is saved to disk.
""" """
...@@ -221,3 +223,45 @@ class DiskCache(Dataset): ...@@ -221,3 +223,45 @@ class DiskCache(Dataset):
def __len__(self): def __len__(self):
return len(self.dataset) return len(self.dataset)
def diskcache_iterator(dataset, location=".cached"):
    """
    Wrap ``dataset`` in a ``_DiskCache`` and return the wrapper.

    Args:
        dataset: dataset to wrap.
        location: forwarded unchanged to ``_DiskCache``; presumably the
            directory used for the cached items (confirm against
            ``_DiskCache.__init__``). Defaults to ".cached".
    """
    cached = _DiskCache(dataset, location)
    return cached
class _ThreadedIterator(threading.Thread):
    """
    Prefetch items from an iterable in a background thread.

    The thread is started by the constructor and pushes items onto a bounded
    queue; iterating over this object pops them off in the same order.

    Args:
        generator: iterable consumed by the background thread.
        maxsize: capacity handed to ``queue.Queue``; the producer blocks once
            this many items are waiting to be consumed.

    Raises:
        Exception: any exception raised by ``generator`` is re-raised in the
            consuming thread once the items produced before it are drained.

    Example:
        >> for i in bg_iterator(range(10), 5):
        >>     print(i)
    """

    class _End:
        # Sentinel marking the end of the stream; compared by identity only.
        pass

    def __init__(self, generator, maxsize):
        threading.Thread.__init__(self)
        self.queue = Queue(maxsize)
        self.generator = generator
        # Error raised by the producer, handed over to the consumer.
        self._error = None
        # Set once the sentinel was consumed so later calls do not block.
        self._exhausted = False
        self.daemon = True
        # All state must be in place before the producer starts running.
        self.start()

    def run(self):
        try:
            for item in self.generator:
                self.queue.put(item)
        except Exception as error:
            # Keep the error so the consumer can re-raise it.
            self._error = error
        finally:
            # Always signal the end, otherwise the consumer would block
            # forever on an empty queue when the producer fails.
            self.queue.put(self._End)

    def __iter__(self):
        return self

    def __next__(self):
        if self._exhausted:
            raise StopIteration
        next_item = self.queue.get()
        # Identity check: items may define an arbitrary (or elementwise) __eq__.
        if next_item is self._End:
            self._exhausted = True
            error, self._error = self._error, None
            if error is not None:
                raise error
            raise StopIteration
        return next_item
def bg_iterator(iterable, maxsize):
    """
    Iterate over ``iterable`` while a background thread prefetches its items.

    Args:
        iterable: source of items, consumed by the background thread.
        maxsize: size of the prefetch queue, forwarded to ``_ThreadedIterator``.

    Returns:
        A ``_ThreadedIterator`` wrapping ``iterable``.
    """
    prefetcher = _ThreadedIterator(iterable, maxsize=maxsize)
    return prefetcher
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment