Unverified Commit 3cd5eed6 authored by Krishna Kalyan, committed by GitHub

Remove redundant dataset utilities (#1086)


Co-authored-by: krishnakalyan3 <skalyan@cloudera.com>
parent 7a3e15bc
 import os
+import csv
 from pathlib import Path
 from typing import Tuple, Union
@@ -8,7 +9,6 @@ from torch.utils.data import Dataset
 from torchaudio.datasets.utils import (
     download_url,
     extract_archive,
-    unicode_csv_reader,
 )

 URL = "aew"
@@ -154,7 +154,7 @@ class CMUARCTIC(Dataset):
         self._text = os.path.join(self._path, self._folder_text, self._file_text)

         with open(self._text, "r") as text:
-            walker = unicode_csv_reader(text, delimiter="\n")
+            walker = csv.reader(text, delimiter="\n")
             self._walker = list(walker)

     def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]:
...
 import os
+import csv
 import warnings
 from pathlib import Path
 from typing import List, Dict, Tuple, Union, Optional

 import torchaudio
-from torchaudio.datasets.utils import unicode_csv_reader
 from torch import Tensor
 from torch.utils.data import Dataset
@@ -79,7 +79,7 @@ class COMMONVOICE(Dataset):
         self._tsv = os.path.join(self._path, tsv)

         with open(self._tsv, "r") as tsv_:
-            walker = unicode_csv_reader(tsv_, delimiter="\t")
+            walker = csv.reader(tsv_, delimiter="\t")
             self._header = next(walker)
             self._walker = list(walker)
...
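On Python 3 the standard csv module consumes unicode text streams directly, which is what makes the removed wrapper redundant. A minimal sketch of the call-site pattern adopted above, with "validated.tsv" as an illustrative file name rather than anything taken from the change:

    import csv

    # "validated.tsv" is an illustrative path; any unicode text stream works.
    with open("validated.tsv", "r") as tsv_:
        walker = csv.reader(tsv_, delimiter="\t")
        header = next(walker)   # first row holds the column names
        rows = list(walker)     # remaining rows, one list of fields per line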
@@ -4,7 +4,7 @@ from typing import List, Tuple, Union
 from pathlib import Path

 import torchaudio
-from torchaudio.datasets.utils import download_url, extract_archive, unicode_csv_reader
+from torchaudio.datasets.utils import download_url, extract_archive
 from torch import Tensor
 from torch.utils.data import Dataset
@@ -75,7 +75,7 @@ class LJSPEECH(Dataset):
            extract_archive(archive)

         with open(self._metadata_path, "r", newline='') as metadata:
-            walker = unicode_csv_reader(metadata, delimiter="|", quoting=csv.QUOTE_NONE)
+            walker = csv.reader(metadata, delimiter="|", quoting=csv.QUOTE_NONE)
             self._walker = list(walker)

     def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]:
...
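The metadata file parsed above is pipe-delimited and its transcript fields may contain literal double quotes, which is presumably why the call passes quoting=csv.QUOTE_NONE; opening the file with newline='' follows the csv module's documented recommendation. A small sketch under those assumptions, with "metadata.csv" standing in for self._metadata_path:

    import csv

    # "metadata.csv" is an illustrative path for this sketch.
    with open("metadata.csv", "r", newline='') as metadata:
        # QUOTE_NONE treats double quotes as ordinary characters rather than
        # field quoting, so transcripts pass through unchanged.
        walker = csv.reader(metadata, delimiter="|", quoting=csv.QUOTE_NONE)
        rows = list(walker)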
-import csv
 import errno
 import hashlib
 import logging
@@ -18,49 +17,6 @@ from torch.utils.data import Dataset
 from torch.utils.model_zoo import tqdm


-def unicode_csv_reader(unicode_csv_data: TextIOWrapper, **kwargs: Any) -> Any:
-    r"""Since the standard csv library does not handle unicode in Python 2, we need a wrapper.
-    Borrowed and slightly modified from the Python docs:
-    https://docs.python.org/2/library/csv.html#csv-examples
-
-    Args:
-        unicode_csv_data (TextIOWrapper): unicode csv data (see example below)
-
-    Examples:
-        >>> from torchaudio.datasets.utils import unicode_csv_reader
-        >>> import io
-        >>> with io.open(data_path, encoding="utf8") as f:
-        >>>     reader = unicode_csv_reader(f)
-    """
-    # Fix field larger than field limit error
-    maxInt = sys.maxsize
-    while True:
-        # decrease the maxInt value by factor 10
-        # as long as the OverflowError occurs.
-        try:
-            csv.field_size_limit(maxInt)
-            break
-        except OverflowError:
-            maxInt = int(maxInt / 10)
-            csv.field_size_limit(maxInt)
-
-    for line in csv.reader(unicode_csv_data, **kwargs):
-        yield line
-
-
-def makedir_exist_ok(dirpath: str) -> None:
-    """
-    Python2 support for os.makedirs(.., exist_ok=True)
-    """
-    try:
-        os.makedirs(dirpath)
-    except OSError as e:
-        if e.errno == errno.EEXIST:
-            pass
-        else:
-            raise
-
-
 def stream_url(url: str,
                start_byte: Optional[int] = None,
                block_size: int = 32 * 1024,
@@ -305,7 +261,7 @@ class _DiskCache(Dataset):
         item = self.dataset[n]
         self._cache[n] = f
-        makedir_exist_ok(self.location)
+        os.makedirs(self.location, exist_ok=True)
         torch.save(item, f)
         return item
...
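Both helpers deleted above existed only for Python 2 compatibility. A hedged sketch of the Python 3 equivalents, with "cache_dir" as an illustrative path; the field-size workaround that unicode_csv_reader carried can still be applied at a call site if a dataset ever hits the limit:

    import csv
    import os
    import sys

    # makedir_exist_ok() replacement: os.makedirs accepts exist_ok since Python 3.2.
    os.makedirs("cache_dir", exist_ok=True)  # "cache_dir" is illustrative only

    # The former workaround for "field larger than field limit": lower
    # sys.maxsize by factors of ten until csv.field_size_limit accepts it.
    max_int = sys.maxsize
    while True:
        try:
            csv.field_size_limit(max_int)
            break
        except OverflowError:
            max_int = int(max_int / 10)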
@@ -226,7 +226,7 @@ def apply_effects_file(
         ...         self.sample_rate = sample_rate
         ...
         ...     def __getitem__(self, index):
-        ...         speed = 0.5 + 1.5 * torch.rand()
+        ...         speed = 0.5 + 1.5 * random.randn()
         ...         effects = [
         ...             ['gain', '-n', '-10'],  # apply 10 db attenuation
         ...             ['remix', '-'],  # merge all the channels
...