Unverified Commit 1efba850 authored by moto's avatar moto Committed by GitHub
Browse files

Remove deprecated dataset utils (#1826)

parent fc4f481b
import torch
from torchaudio_unittest.common_utils import (
TorchaudioTestCase,
TempDirMixin
)
from torchaudio.datasets import utils as dataset_utils
class Dataset(torch.utils.data.Dataset):
    """Minimal synthetic dataset for exercising the iterator helpers.

    Item ``n`` is a tuple ``(waveform, sample_rate)`` where the waveform
    is a ``(2, 256)`` tensor filled with the value ``n`` and the sample
    rate is fixed at 8000.
    """

    def __getitem__(self, n):
        # Constant-valued tensor: cheap to construct and easy to verify.
        return n * torch.ones(2, 256), 8000

    def __len__(self) -> int:
        return 2

    def __iter__(self):
        # Yield every item in index order.
        return (self[index] for index in range(len(self)))
class TestIterator(TorchaudioTestCase, TempDirMixin):
    """Smoke tests for the deprecated dataset iterator helpers."""

    backend = 'default'

    def test_disckcache_iterator(self):
        # First access serializes the item to the temp dir; the second
        # access must read it back from disk without error.
        cached = dataset_utils.diskcache_iterator(Dataset(), self.get_base_temp_dir())
        cached[0]  # Save
        cached[0]  # Load

    def test_bg_iterator(self):
        # The background-prefetch iterator must yield every item and stop.
        prefetched = dataset_utils.bg_iterator(Dataset(), 5)
        for _ in prefetched:
            pass
from .commonvoice import COMMONVOICE
from .librispeech import LIBRISPEECH
from .speechcommands import SPEECHCOMMANDS
from .utils import bg_iterator, diskcache_iterator
from .vctk import VCTK_092
from .gtzan import GTZAN
from .yesno import YESNO
......@@ -23,7 +22,5 @@ __all__ = [
"CMUARCTIC",
"CMUDict",
"LIBRITTS",
"diskcache_iterator",
"bg_iterator",
"TEDLIUM",
]
......@@ -2,19 +2,13 @@ import hashlib
import logging
import os
import tarfile
import threading
import urllib
import urllib.request
import zipfile
from queue import Queue
from typing import Any, Iterable, List, Optional
import torch
from torch.utils.data import Dataset
from torch.utils.model_zoo import tqdm
from torchaudio._internal.module_utils import deprecated
def stream_url(url: str,
start_byte: Optional[int] = None,
......@@ -203,82 +197,3 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo
pass
raise NotImplementedError("We currently only support tar.gz, tgz, and zip achives.")
class _DiskCache(Dataset):
"""
Wrap a dataset so that, whenever a new item is returned, it is saved to disk.
"""
def __init__(self, dataset: Dataset, location: str = ".cached") -> None:
self.dataset = dataset
self.location = location
self._id = id(self)
self._cache: List = [None] * len(dataset)
def __getitem__(self, n: int) -> Any:
if self._cache[n]:
f = self._cache[n]
return torch.load(f)
f = str(self._id) + "-" + str(n)
f = os.path.join(self.location, f)
item = self.dataset[n]
self._cache[n] = f
os.makedirs(self.location, exist_ok=True)
torch.save(item, f)
return item
def __len__(self) -> int:
return len(self.dataset)
@deprecated('', version='0.11')
def diskcache_iterator(dataset: Dataset, location: str = ".cached") -> Dataset:
    """Return ``dataset`` wrapped so fetched items are cached on disk under ``location``."""
    return _DiskCache(dataset, location=location)
class _ThreadedIterator(threading.Thread):
"""
Prefetch the next queue_length items from iterator in a background thread.
Example:
>> for i in bg_iterator(range(10)):
>> print(i)
"""
class _End:
pass
def __init__(self, generator: Iterable, maxsize: int) -> None:
threading.Thread.__init__(self)
self.queue: Queue = Queue(maxsize)
self.generator = generator
self.daemon = True
self.start()
def run(self) -> None:
for item in self.generator:
self.queue.put(item)
self.queue.put(self._End)
def __iter__(self) -> Any:
return self
def __next__(self) -> Any:
next_item = self.queue.get()
if next_item == self._End:
raise StopIteration
return next_item
# Required for Python 2.7 compatibility
def next(self) -> Any:
return self.__next__()
@deprecated('', version='0.11')
def bg_iterator(iterable: Iterable, maxsize: int) -> Any:
    """Iterate over ``iterable`` while prefetching up to ``maxsize`` items in a background thread."""
    return _ThreadedIterator(iterable, maxsize=maxsize)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment