Unverified Commit 8920802d authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

new dataset format with librispeech and commonvoice (#303)

* new dataset format.
* add basic test.
* files for testing.
* serialization using torch.
* add diskcache.
* adding deprecation warnings.
* removing legacy.
* warning about transforms.
* detecting file format using reader.
parent b8203182
client_id path sentence up_votes down_votes age gender accent
00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 common_voice_tt_00000000.mp3 test. 1 0 thirties female
import os
import unittest
from torchaudio.datasets.commonvoice import COMMONVOICE
from torchaudio.datasets.librispeech import LIBRISPEECH
from torchaudio.datasets.utils import DiskCache
from torchaudio.datasets.vctk import VCTK
from torchaudio.datasets.yesno import YESNO
class TestDatasets(unittest.TestCase):
    """Smoke tests for the new dataset wrappers: constructing each
    dataset against the local ``assets`` folder and fetching item 0
    must not raise."""

    path = "assets"

    def test_yesno(self):
        YESNO(self.path, return_dict=True)[0]

    def test_vctk(self):
        VCTK(self.path, return_dict=True)[0]

    def test_librispeech(self):
        LIBRISPEECH(self.path, "dev-clean")[0]

    def test_commonvoice(self):
        root = os.path.join(self.path, "commonvoice")
        COMMONVOICE(root, "train.tsv", "tatar")[0]

    def test_commonvoice_diskcache(self):
        root = os.path.join(self.path, "commonvoice")
        cached = DiskCache(COMMONVOICE(root, "train.tsv", "tatar"))
        # First access serializes the item to disk ...
        cached[0]
        # ... second access reloads it from the cache file.
        cached[0]
# Allow running this test module directly: `python thisfile.py`.
if __name__ == "__main__":
    unittest.main()
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import torch
import torchaudio
import unittest
import common_utils
import torchaudio.datasets.vctk as vctk
class TestVCTK(unittest.TestCase):
    """Unit tests for the legacy VCTK helpers (audio-file detection,
    manifest building, audio reading, transcript parsing) using the
    checked-in test assets."""

    def setUp(self):
        self.test_dirpath, self.test_dir = common_utils.create_temp_assets_dir()

    def get_full_path(self, file):
        """Absolute path of an asset file under the temp assets dir."""
        return os.path.join(self.test_dirpath, 'assets', file)

    def test_is_audio_file(self):
        self.assertTrue(vctk.is_audio_file('foo.wav'))
        self.assertTrue(vctk.is_audio_file('foo.WAV'))
        self.assertFalse(vctk.is_audio_file('foo.bar'))

    def test_make_manifest(self):
        expected = sorted(
            self.get_full_path(name)
            for name in (
                'kaldi_file.wav', 'kaldi_file_8000.wav',
                'sinewave.wav', 'steam-train-whistle-daniel_simon.mp3',
                'dtmf_30s_stereo.mp3', 'whitenoise_1min.mp3', 'whitenoise.mp3',
            )
        )
        found = sorted(vctk.make_manifest(self.test_dirpath))
        self.assertEqual(expected, found, msg='files %s did not match audios %s' % (expected, found))

    def test_read_audio_downsample_false(self):
        signal, rate = vctk.read_audio(self.get_full_path('kaldi_file.wav'), downsample=False)
        self.assertEqual(rate, 16000, msg='incorrect sample rate %d' % (rate))
        self.assertEqual(signal.shape, (1, 20), msg='incorrect shape %s' % (str(signal.shape)))

    def test_read_audio_downsample_true(self):
        signal, rate = vctk.read_audio(self.get_full_path('kaldi_file.wav'), downsample=True)
        self.assertEqual(rate, 16000, msg='incorrect sample rate %d' % (rate))
        self.assertEqual(signal.shape, (1, 20), msg='incorrect shape %s' % (str(signal.shape)))

    def test_load_txts(self):
        expected_utterances = {'file2': 'word5 word6\n', 'file1': 'word1 word2\n'}
        utterences = vctk.load_txts(self.test_dirpath)
        self.assertEqual(utterences, expected_utterances,
                         msg='%s did not match %s' % (utterences, expected_utterances))

    def test_vctk(self):
        # TODO somehow test download=True, the dataset is too big download ~10 GB for
        # each test so need a way to mock it
        self.assertRaises(RuntimeError, vctk.VCTK, self.test_dirpath, download=False)
# Allow running this test module directly: `python thisfile.py`.
if __name__ == '__main__':
    unittest.main()
# Public API of torchaudio.datasets.
# BUG FIX: the scraped version imported YESNO twice and assigned
# __all__ twice (the first assignment was dead code, immediately
# shadowed by the second). Keep one import per name and the final,
# complete __all__.
from .commonvoice import COMMONVOICE
from .librispeech import LIBRISPEECH
from .utils import DiskCache
from .vctk import VCTK
from .yesno import YESNO

__all__ = ("COMMONVOICE", "LIBRISPEECH", "VCTK", "YESNO", "DiskCache")
import os
import torchaudio
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url, extract_archive, unicode_csv_reader
# Default TSV should be one of
# dev.tsv
# invalidated.tsv
# other.tsv
# test.tsv
# train.tsv
# validated.tsv

# Default language name; COMMONVOICE.__init__ resolves known language
# names to a CommonVoice download URL.
URL = "english"
# Default split file read from the dataset root.
TSV = "train.tsv"
def load_commonvoice_item(line, header, path, folder_audio):
    """Build one CommonVoice sample from a TSV row.

    Args:
        line: one row of the split TSV (list of cells).
        header: column names from the TSV header row.
        path: dataset root directory.
        folder_audio: subfolder containing the audio clips.

    Returns:
        dict mapping each header column to its cell value, plus
        "waveform" and "sample_rate" for the loaded audio.
    """
    # Column 1 of the row is the audio file name (the "path" column).
    fileid = line[1]
    waveform, sample_rate = torchaudio.load(
        os.path.join(path, folder_audio, fileid)
    )

    item = dict(zip(header, line))
    item["waveform"] = waveform
    item["sample_rate"] = sample_rate
    return item
class COMMONVOICE(Dataset):
    """CommonVoice dataset.

    Each item is a dict merging one TSV metadata row (keyed by the TSV
    header) with the loaded "waveform" and "sample_rate".
    """

    _ext_txt = ".txt"
    _ext_audio = ".mp3"
    _folder_audio = "clips"

    def __init__(self, root, tsv=TSV, url=URL, download=False):
        """
        Args:
            root: directory where the dataset (or its archive) lives.
            tsv: name of the split TSV file, e.g. "train.tsv".
            url: either a known language name (resolved below to the
                CommonVoice download URL) or a full archive URL.
            download: download and extract the archive when the data
                directory is missing.
        """
        languages = {
            "tatar": "tt",
            "english": "en",
            "german": "de",
            "french": "fr",
            "welsh": "cy",
            "breton": "br",
            "chuvash": "cv",
            "turkish": "tr",
            "kyrgyz": "ky",
            "irish": "ga-IE",
            "kabyle": "kab",
            "catalan": "ca",
            "taiwanese": "zh-TW",
            "slovenian": "sl",
            "italian": "it",
            "dutch": "nl",
            "hakha chin": "cnh",
            "esperanto": "eo",
            "estonian": "et",
            "persian": "fa",
            "basque": "eu",
            "spanish": "es",
            "chinese": "zh-CN",
            "mongolian": "mn",
            "sakha": "sah",
            "dhivehi": "dv",
            "kinyarwanda": "rw",
            "swedish": "sv-SE",
            "russian": "ru",
        }

        # BUG FIX: the original tested `url is languages` — an identity
        # comparison against the dict object itself, which is always
        # False, so language names were never resolved to a URL. The
        # intent is a membership test.
        if url in languages:
            ext_archive = ".tar.gz"
            language = languages[url]
            base_url = (
                "https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4"
                + ".s3.amazonaws.com/cv-corpus-3/"
            )
            url = base_url + language + ext_archive

        archive = os.path.basename(url)
        archive = os.path.join(root, archive)
        self._path = root

        if download:
            # NOTE(review): `self._path` is the root itself, which
            # typically already exists, so this guard may prevent the
            # download from ever triggering — confirm against callers.
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    download_url(url, root)
                extract_archive(archive)

        self._tsv = os.path.join(root, tsv)

        # First TSV row is the header; the rest are the samples.
        with open(self._tsv, "r") as tsv_file:
            walker = unicode_csv_reader(tsv_file, delimiter="\t")
            self._header = next(walker)
            self._walker = list(walker)

    def __getitem__(self, n):
        """Return the n-th sample as a dict."""
        line = self._walker[n]
        return load_commonvoice_item(line, self._header, self._path, self._folder_audio)

    def __len__(self):
        return len(self._walker)
import os
import torchaudio
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
download_url,
extract_archive,
unicode_csv_reader,
walk_files,
)
# Default LibriSpeech subset; LIBRISPEECH.__init__ resolves known
# subset names to an OpenSLR download URL.
URL = "train-clean-100"
# Top-level directory name created by extracting the archive.
FOLDER_IN_ARCHIVE = "LibriSpeech"
def load_librispeech_item(fileid, path, ext_audio, ext_txt):
    """Load one LibriSpeech example.

    Args:
        fileid: utterance id of the form "<speaker>-<chapter>-<utterance>".
        path: dataset root containing <speaker>/<chapter>/ directories.
        ext_audio: audio extension, e.g. ".flac".
        ext_txt: transcript extension, e.g. ".trans.txt".

    Returns:
        dict with speaker/chapter/utterance ids, the transcript text,
        the waveform and its sample rate.

    Raises:
        ValueError: if no transcript line matches ``fileid``.
    """
    speaker, chapter, utterance = fileid.split("-")

    file_text = speaker + "-" + chapter + ext_txt
    file_text = os.path.join(path, speaker, chapter, file_text)

    file_audio = speaker + "-" + chapter + "-" + utterance + ext_audio
    file_audio = os.path.join(path, speaker, chapter, file_audio)

    # Load audio
    waveform, sample_rate = torchaudio.load(file_audio)

    # Load text: each transcript line is "<fileid> <transcript>".
    # Use a with-block so the file handle is not leaked.
    with open(file_text) as ft:
        for line in ft:
            fileid_text, content = line.strip().split(" ", 1)
            # BUG FIX: the original compared the transcript id against
            # `file_audio` (the full audio *path*), which can never
            # match, so every lookup raised ValueError. The transcript
            # id must be compared with the utterance id.
            if fileid == fileid_text:
                break
        else:
            # Translation not found
            raise ValueError

    return {
        "speaker_id": speaker,
        "chapter_id": chapter,
        "utterance_id": utterance,
        "utterance": content,
        "waveform": waveform,
        "sample_rate": sample_rate,
    }
class LIBRISPEECH(Dataset):
    """LibriSpeech dataset; items are the dicts produced by
    ``load_librispeech_item``."""

    _ext_txt = ".trans.txt"
    _ext_audio = ".flac"

    def __init__(
        self, root, url=URL, folder_in_archive=FOLDER_IN_ARCHIVE, download=False
    ):
        """
        Args:
            root: directory where the dataset is found or extracted to.
            url: a known subset name (e.g. "dev-clean") or a full URL.
            folder_in_archive: top-level folder inside the archive.
            download: fetch and extract the archive if missing.
        """
        subsets = (
            "dev-clean",
            "test-clean",
            "test-other",
            "train-clean-100",
            "train-clean-360",
            "train-other-500",
        )
        if url in subsets:
            # Resolve a bare subset name to its OpenSLR download URL.
            ext_archive = ".tar.gz"
            base_url = "http://www.openslr.org/resources/12/"
            url = os.path.join(base_url, url + ext_archive)

        basename = os.path.basename(url)
        archive = os.path.join(root, basename)

        # Strip the archive extension to get the subset folder name.
        basename = basename.split(".")[0]
        folder_in_archive = os.path.join(folder_in_archive, basename)
        self._path = os.path.join(root, folder_in_archive)

        if download:
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    download_url(url, root)
                extract_archive(archive)

        self._walker = list(
            walk_files(
                self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True
            )
        )

    def __getitem__(self, n):
        """Return the n-th example as a dict."""
        return load_librispeech_item(
            self._walker[n], self._path, self._ext_audio, self._ext_txt
        )

    def __len__(self):
        return len(self._walker)
import csv
import errno
import gzip
import hashlib
import logging
import os
import sys
import tarfile
import zipfile
import six
import torch
import torchaudio
from torch.utils.data import Dataset
from torch.utils.model_zoo import tqdm
def unicode_csv_reader(unicode_csv_data, **kwargs):
    r"""Since the standard csv library does not handle unicode in Python 2, we need a wrapper.
    Borrowed and slightly modified from the Python docs:
    https://docs.python.org/2/library/csv.html#csv-examples

    Arguments:
        unicode_csv_data: unicode csv data (see example below)

    Examples:
        >>> from torchaudio.datasets.utils import unicode_csv_reader
        >>> import io
        >>> with io.open(data_path, encoding="utf8") as f:
        >>>     reader = unicode_csv_reader(f)
    """
    # Fix "field larger than field limit" errors: start from the
    # platform maximum and back off by factors of 10 until the csv
    # module accepts the value.
    limit = sys.maxsize
    while True:
        try:
            csv.field_size_limit(limit)
        except OverflowError:
            limit = int(limit / 10)
        else:
            break
    csv.field_size_limit(limit)

    if six.PY2:
        # csv.py doesn't do Unicode; encode temporarily as UTF-8,
        # then decode each cell back to Unicode.
        encoded_rows = csv.reader(utf_8_encoder(unicode_csv_data), **kwargs)
        for row in encoded_rows:
            yield [cell.decode("utf-8") for cell in row]
    else:
        for row in csv.reader(unicode_csv_data, **kwargs):
            yield row
def gen_bar_updater():
    """Return a ``reporthook`` for ``urlretrieve`` that drives a tqdm
    progress bar (total is learned from the first callback)."""
    progress = tqdm(total=None)

    def update(count, block_size, total_size):
        # Learn the total once the server reports a content length.
        if progress.total is None and total_size:
            progress.total = total_size
        downloaded = count * block_size
        progress.update(downloaded - progress.n)

    return update
def makedir_exist_ok(dirpath):
    """
    Python2 support for os.makedirs(.., exist_ok=True): create
    ``dirpath`` (and parents), ignoring an already-existing path.
    """
    try:
        os.makedirs(dirpath)
    except OSError as err:
        # Only the "already exists" error is expected and swallowed.
        if err.errno != errno.EEXIST:
            raise
def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    from six.moves import urllib

    root = os.path.expanduser(root)
    filename = filename or os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir_exist_ok(root)

    if os.path.isfile(fpath):
        # Already on disk — skip the download entirely.
        print("Using downloaded file: " + fpath)
        return

    try:
        print("Downloading " + url + " to " + fpath)
        urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
    except (urllib.error.URLError, IOError) as e:
        if url[:5] != "https":
            raise e
        # Some hosts reject https; retry once over plain http.
        url = url.replace("https:", "http:")
        print(
            "Failed download. Trying https -> http instead."
            " Downloading " + url + " to " + fpath
        )
        urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
def extract_archive(from_path, to_path=None, overwrite=False):
    """Extract archive.

    Arguments:
        from_path: the path of the archive.
        to_path: the root path of the extraced files (directory of from_path)
        overwrite: overwrite existing files (False)

    Returns:
        List of paths to extracted files even if not overwritten.

    Examples:
        >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
        >>> from_path = './validation.tar.gz'
        >>> to_path = './'
        >>> torchaudio.datasets.utils.download_from_url(url, from_path)
        >>> torchaudio.datasets.utils.extract_archive(from_path, to_path)
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    # Try tar first; a non-tar archive raises ReadError at open time.
    try:
        tar = tarfile.open(from_path, "r")
    except tarfile.ReadError:
        pass
    else:
        with tar:
            logging.info("Opened tar file {}.".format(from_path))
            extracted = []
            for member in tar:
                target = os.path.join(to_path, member.name)
                if member.isfile():
                    extracted.append(target)
                    if os.path.exists(target):
                        logging.info("{} already extracted.".format(target))
                        if not overwrite:
                            continue
                # NOTE(review): tar.extract trusts member names — a
                # malicious archive could escape `to_path` (path
                # traversal); consider validating member.name.
                tar.extract(member, to_path)
            return extracted

    # Fall back to zip.
    try:
        zf = zipfile.ZipFile(from_path, "r")
    except zipfile.BadZipFile:
        pass
    else:
        with zf:
            logging.info("Opened zip file {}.".format(from_path))
            names = zf.namelist()
            for name in names:
                target = os.path.join(to_path, name)
                if os.path.exists(target):
                    logging.info("{} already extracted.".format(target))
                    if not overwrite:
                        continue
                zf.extract(name, to_path)
            return names

    raise NotImplementedError("We currently only support tar.gz, tgz, and zip achives.")
def walk_files(root, suffix, prefix=False, remove_suffix=False):
    """List recursively all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
        remove_suffix (bool, optional): If true, strips the suffix from each
            returned name
    """
    root = os.path.expanduser(root)
    for dirpath, _, filenames in os.walk(root):
        for name in filenames:
            if name.endswith(suffix):
                if remove_suffix:
                    name = name[: -len(suffix)]
                if prefix:
                    # BUG FIX: join against the directory the file was
                    # actually found in, not the walk root — the
                    # original produced nonexistent paths for files in
                    # subdirectories.
                    name = os.path.join(dirpath, name)
                yield name
class DiskCache(Dataset):
    """
    Wrap a dataset so that, whenever a new item is returned, it is saved to disk.
    """

    def __init__(self, dataset, location=".cached"):
        self.dataset = dataset
        self.location = location
        # id(self) makes cache file names unique per wrapper instance.
        self._id = id(self)
        # One slot per item; holds the on-disk path once cached.
        self._cache = [None] * len(dataset)

    def __getitem__(self, n):
        cached_path = self._cache[n]
        if cached_path:
            # Item was serialized on a previous access — reload it.
            return torch.load(cached_path)

        filepath = os.path.join(self.location, str(self._id) + "-" + str(n))
        item = self.dataset[n]
        self._cache[n] = filepath
        makedir_exist_ok(self.location)
        torch.save(item, filepath)
        return item

    def __len__(self):
        return len(self.dataset)
from __future__ import absolute_import, division, print_function, unicode_literals
import torch.utils.data as data
import os
import os.path
import shutil
import errno
import torch
import torchaudio
# Recognized audio file extensions (both lower- and upper-case).
AUDIO_EXTENSIONS = [
    '.wav', '.mp3', '.flac', '.sph', '.ogg', '.opus',
    '.WAV', '.MP3', '.FLAC', '.SPH', '.OGG', '.OPUS',
]

import warnings


def is_audio_file(filename):
    """Return True if ``filename`` ends with a known audio extension."""
    return filename.endswith(tuple(AUDIO_EXTENSIONS))
import torchaudio
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url, extract_archive, walk_files
# Download URL for the VCTK corpus archive.
URL = "http://homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"
# Top-level directory name created by extracting the archive.
FOLDER_IN_ARCHIVE = "VCTK-Corpus"
def make_manifest(dir):
    """Recursively collect the paths of all audio files found in the
    immediate subdirectories of ``dir`` (files directly in ``dir`` are
    skipped)."""
    dir = os.path.expanduser(dir)
    manifest = []
    for entry in sorted(os.listdir(dir)):
        subdir = os.path.join(dir, entry)
        if not os.path.isdir(subdir):
            continue
        for root, _, filenames in sorted(os.walk(subdir)):
            manifest.extend(
                os.path.join(root, name)
                for name in filenames
                if is_audio_file(name)
            )
    return manifest
# NOTE(review): this span is a diff-scrape interleaving of the new
# `load_vctk_item` (whose body continues further down the file) with
# the legacy `read_audio` helper; code kept token-for-token, nesting
# reconstructed so the names it references resolve.
def load_vctk_item(
    fileid, path, ext_audio, ext_txt, folder_audio, folder_txt, downsample=False
):
    # VCTK file ids look like "<speaker>_<utterance>".
    speaker, utterance = fileid.split("_")

    # Read text
    file_txt = os.path.join(path, folder_txt, speaker, fileid + ext_txt)
    with open(file_txt) as file_text:
        content = file_text.readlines()[0]

    def read_audio(fp, downsample=True):
        # Read wav
        file_audio = os.path.join(path, folder_audio, speaker, fileid + ext_audio)
        if downsample:
            # Legacy
            # NOTE(review): the duplicated set_input_file and
            # sox_build_flow_effects calls look like diff artifacts —
            # confirm against the original file.
            E = torchaudio.sox_effects.SoxEffectsChain()
            E.set_input_file(fp)
            E.set_input_file(file_audio)
            E.append_effect_to_chain("gain", ["-h"])
            E.append_effect_to_chain("channels", [1])
            E.append_effect_to_chain("rate", [16000])
            E.append_effect_to_chain("gain", ["-rh"])
            E.append_effect_to_chain("dither", ["-s"])
            sig, sr = E.sox_build_flow_effects()
            waveform, sample_rate = E.sox_build_flow_effects()
        else:
            sig, sr = torchaudio.load(fp)
        sig = sig.contiguous()
        return sig, sr
def load_txts(dir):
    """Create a dictionary with all the text of the audio transcriptions.

    Maps each .txt file's basename (without extension) to its first
    line, scanning the immediate subdirectories of ``dir``.
    """
    dir = os.path.expanduser(dir)
    utterences = dict()
    for entry in sorted(os.listdir(dir)):
        subdir = os.path.join(dir, entry)
        if not os.path.isdir(subdir):
            continue
        for root, _, filenames in sorted(os.walk(subdir)):
            for name in filenames:
                if not name.endswith(".txt"):
                    continue
                key = os.path.basename(name).rsplit(".", 1)[0]
                with open(os.path.join(root, name), "r") as f:
                    utterences[key] = f.readline()
    return utterences
class VCTK(data.Dataset):
    r"""`VCTK <http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html>`_ Dataset.

    `alternate url <http://datashare.is.ed.ac.uk/handle/10283/2651>`_

    Args:
        root (str): Root directory of dataset where ``processed/training.pt``
            and ``processed/test.pt`` exist.
        downsample (bool, optional): Whether to downsample the signal (Default: ``True``)
        transform (Callable, optional): A function/transform that takes in an raw audio
            and returns a transformed version. E.g, ``transforms.Spectrogram``. (Default: ``None``)
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it. (Default: ``None``)
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. (Default: ``True``)
        dev_mode(bool, optional): If true, clean up is not performed on downloaded
            files. Useful to keep raw audio and transcriptions. (Default: ``False``)
    """

    raw_folder = 'vctk/raw'
    processed_folder = 'vctk/processed'
    url = 'http://homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz'
    dset_path = 'VCTK-Corpus'

    def __init__(self, root, downsample=True, transform=None, target_transform=None, download=False, dev_mode=False):
        self.root = os.path.expanduser(root)
        # NOTE(review): the remainder of this method is the tail of the
        # new `load_vctk_item` helper, spliced here by the diff scrape;
        # `file_audio`, `speaker`, `utterance`, `content` are not in
        # scope in this initializer.
        waveform, sample_rate = torchaudio.load(file_audio)
        return {
            "speaker_id": speaker,
            "utterance_id": utterance,
            "utterance": content,
            "waveform": waveform,
            "sample_rate": sample_rate,
        }
class VCTK(Dataset):
    # NOTE(review): new-format VCTK class; the diff scrape interleaved
    # it with fields and methods of the legacy class above. Code kept
    # token-for-token with interleave points annotated.

    _folder_txt = "txt"
    _folder_audio = "wav48"
    _ext_txt = ".txt"
    _ext_audio = ".wav"

    def __init__(
        self,
        root,
        url=URL,
        folder_in_archive=FOLDER_IN_ARCHIVE,
        download=False,
        downsample=False,
        transform=None,
        target_transform=None,
        return_dict=False,
    ):
        if not return_dict:
            warnings.warn(
                "In the next version, the item returned will be a dictionary. "
                "Please use `return_dict=True` to enable this behavior now, "
                "and suppress this warning.",
                DeprecationWarning,
            )
        if downsample:
            warnings.warn(
                # NOTE(review): the trailing comma after the first string
                # makes the next string a bogus `category` argument —
                # likely meant as implicit string concatenation; kept
                # as scraped.
                "In the next version, transforms will not be part of the dataset. "
                "Please use `downsample=False` to enable this behavior now, ",
                "and suppress this warning.",
                DeprecationWarning,
            )
        if transform is not None or target_transform is not None:
            warnings.warn(
                "In the next version, transforms will not be part of the dataset. "
                "Please remove the option `transform=True` and "
                "`target_transform=True` to suppress this warning.",
                DeprecationWarning,
            )

        self.downsample = downsample
        self.transform = transform
        self.target_transform = target_transform
        # NOTE(review): the following attributes come from the legacy
        # initializer interleaved here (`dev_mode` is not a parameter
        # of this __init__).
        self.dev_mode = dev_mode
        self.data = []
        self.labels = []
        self.chunk_size = 1000
        self.num_samples = 0
        self.max_len = 0
        self.cached_pt = 0
        self.return_dict = return_dict

        archive = os.path.basename(url)
        archive = os.path.join(root, archive)
        self._path = os.path.join(root, folder_in_archive)

        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')
        self._read_info()
        self.data, self.labels = torch.load(os.path.join(
            self.root, self.processed_folder, "vctk_{:04d}.pt".format(self.cached_pt)))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            Tuple[torch.Tensor, int]: The output tuple (image, target) where target
            is index of the target class.
        """
        # Legacy chunked-cache lookup: reload the right .pt chunk.
        if self.cached_pt != index // self.chunk_size:
            self.cached_pt = int(index // self.chunk_size)
            self.data, self.labels = torch.load(os.path.join(
                self.root, self.processed_folder, "vctk_{:04d}.pt".format(self.cached_pt)))
        index = index % self.chunk_size
        audio, target = self.data[index], self.labels[index]
        # NOTE(review): from here the new-version __init__'s download
        # logic is interleaved; `archive`, `url`, `root` are not in
        # scope in this method.
        if not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                download_url(url, root)
            extract_archive(archive)
        if not os.path.isdir(self._path):
            raise RuntimeError(
                "Dataset not found. Please use `download=True` to download it."
            )
        walker = walk_files(
            self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True
        )
        self._walker = list(walker)

    def __getitem__(self, n):
        # NOTE(review): second __getitem__ definition (new version);
        # shadows the legacy one above. Lines referencing `audio` /
        # `target` are legacy leftovers from the interleave.
        fileid = self._walker[n]
        item = load_vctk_item(
            fileid,
            self._path,
            self._ext_audio,
            self._ext_txt,
            self._folder_audio,
            self._folder_txt,
        )
        # Legacy
        waveform = item["waveform"]
        if self.transform is not None:
            audio = self.transform(audio)
            waveform = self.transform(waveform)
        item["waveform"] = waveform
        # Legacy
        utterance = item["utterance"]
        if self.target_transform is not None:
            target = self.target_transform(target)
            utterance = self.target_transform(utterance)
        item["utterance"] = utterance
        if self.return_dict:
            return item
        return audio, target
        # Legacy
        return item["waveform"], item["utterance"]

    def __len__(self):
        return self.num_samples
    def _check_exists(self):
        # True once the processed info file has been written.
        return os.path.exists(os.path.join(self.root, self.processed_folder, "vctk_info.txt"))

    def _write_info(self, num_items):
        # Persist the sample count and longest-signal length.
        info_path = os.path.join(
            self.root, self.processed_folder, "vctk_info.txt")
        with open(info_path, "w") as f:
            f.write("num_samples,{}\n".format(num_items))
            f.write("max_len,{}\n".format(self.max_len))

    def _read_info(self):
        # Inverse of _write_info: restore num_samples and max_len.
        info_path = os.path.join(
            self.root, self.processed_folder, "vctk_info.txt")
        with open(info_path, "r") as f:
            self.num_samples = int(f.readline().split(",")[1])
            self.max_len = int(f.readline().split(",")[1])

    def download(self):
        """Download the VCTK data if it doesn't exist in processed_folder already."""
        from six.moves import urllib
        import tarfile

        if self._check_exists():
            return

        raw_abs_dir = os.path.join(self.root, self.raw_folder)
        processed_abs_dir = os.path.join(self.root, self.processed_folder)
        dset_abs_path = os.path.join(
            self.root, self.raw_folder, self.dset_path)

        # download files
        try:
            os.makedirs(os.path.join(self.root, self.processed_folder))
            os.makedirs(os.path.join(self.root, self.raw_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise

        url = self.url
        print('Downloading ' + url)
        filename = url.rpartition('/')[2]
        file_path = os.path.join(self.root, self.raw_folder, filename)
        if not os.path.isfile(file_path):
            urllib.request.urlretrieve(url, file_path)
        if not os.path.exists(dset_abs_path):
            with tarfile.open(file_path) as zip_f:
                zip_f.extractall(raw_abs_dir)
        else:
            print("Using existing raw folder")
        if not self.dev_mode:
            os.unlink(file_path)

        # process and save as torch files
        torchaudio.initialize_sox()
        print('Processing...')
        shutil.copyfile(
            os.path.join(dset_abs_path, "COPYING"),
            os.path.join(processed_abs_dir, "VCTK_COPYING")
        )
        audios = make_manifest(dset_abs_path)
        utterences = load_txts(dset_abs_path)
        self.max_len = 0
        print("Found {} audio files and {} utterences".format(
            len(audios), len(utterences)))
        # Serialize the corpus in chunks of self.chunk_size items.
        for n in range(len(audios) // self.chunk_size + 1):
            tensors = []
            labels = []
            lengths = []
            st_idx = n * self.chunk_size
            end_idx = st_idx + self.chunk_size
            for i, f in enumerate(audios[st_idx:end_idx]):
                txt_dir = os.path.dirname(f).replace("wav48", "txt")
                if os.path.exists(txt_dir):
                    f_rel_no_ext = os.path.basename(f).rsplit(".", 1)[0]
                    sig = read_audio(f, downsample=self.downsample)[0]
                    tensors.append(sig)
                    lengths.append(sig.size(1))
                    labels.append(utterences[f_rel_no_ext])
                    self.max_len = sig.size(1) if sig.size(
                        1) > self.max_len else self.max_len
            # sort sigs/labels: longest -> shortest
            tensors, labels = zip(*[(b, c) for (a, b, c) in sorted(
                zip(lengths, tensors, labels), key=lambda x: x[0], reverse=True)])
            data = (tensors, labels)
            torch.save(
                data,
                os.path.join(
                    self.root,
                    self.processed_folder,
                    "vctk_{:04d}.pt".format(n)
                )
            )
            self._write_info((n * self.chunk_size) + i + 1)
        if not self.dev_mode:
            shutil.rmtree(raw_abs_dir, ignore_errors=True)
        torchaudio.shutdown_sox()
        print('Done!')
        # NOTE(review): stray line from the new-version __len__, spliced
        # here by the diff scrape (unreachable as placed).
        return len(self._walker)
from __future__ import absolute_import, division, print_function, unicode_literals
import torch.utils.data as data
import os
import os.path
import shutil
import errno
import torch
import warnings
import torchaudio
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url, extract_archive, walk_files
# Download URL for the YesNo archive.
URL = "http://www.openslr.org/resources/1/waves_yesno.tar.gz"
# Top-level directory name created by extracting the archive.
FOLDER_IN_ARCHIVE = "waves_yesno"
def load_yesno_item(fileid, path, ext_audio):
    """Load one YesNo example.

    Args:
        fileid: file name without extension; its underscore-separated
            tokens form the label sequence.
        path: directory containing the audio files.
        ext_audio: audio extension, e.g. ".wav".

    Returns:
        dict with "label", "waveform" and "sample_rate".
    """
    # Read label
    label = fileid.split("_")

    # Read wav
    waveform, sample_rate = torchaudio.load(
        os.path.join(path, fileid + ext_audio)
    )

    return {"label": label, "waveform": waveform, "sample_rate": sample_rate}
class YESNO(Dataset):
    # NOTE(review): new-format YESNO class; the rest of its initializer
    # and its methods are interleaved with the legacy class below by
    # the diff scrape.

    _ext_audio = ".wav"

    def __init__(
        self,
        root,
        url=URL,
        folder_in_archive=FOLDER_IN_ARCHIVE,
        download=False,
        transform=None,
        target_transform=None,
        return_dict=False,
    ):
        if not return_dict:
            warnings.warn(
                "In the next version, the item returned will be a dictionary. "
                "Please use `return_dict=True` to enable this behavior now, "
                "and suppress this warning.",
                DeprecationWarning,
            )
        if transform is not None or target_transform is not None:
            warnings.warn(
                "In the next version, transforms will not be part of the dataset. "
                "Please remove the option `transform=True` and "
                "`target_transform=True` to suppress this warning.",
                DeprecationWarning,
            )
class YESNO(data.Dataset):
    r"""`YesNo Hebrew <http://www.openslr.org/1/>`_ Dataset.

    Args:
        root (str): Root directory of dataset where ``processed/training.pt``
            and ``processed/test.pt`` exist.
        transform (Callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.Spectrogram``. (
            Default: ``None``)
        target_transform (Callable, optional): A function/transform that takes in the
            target and transforms it. (Default: ``None``)
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. (Default: ``False``)
        dev_mode(bool, optional): If true, clean up is not performed on downloaded
            files. Useful to keep raw audio and transcriptions. (Default: ``False``)
    """

    raw_folder = 'yesno/raw'
    processed_folder = 'yesno/processed'
    url = 'http://www.openslr.org/resources/1/waves_yesno.tar.gz'
    dset_path = 'waves_yesno'
    processed_file = 'yesno.pt'

    def __init__(self, root, transform=None, target_transform=None, download=False, dev_mode=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.dev_mode = dev_mode
        self.data = []
        self.labels = []
        self.num_samples = 0
        self.max_len = 0
        # NOTE(review): `return_dict`, `url` and `folder_in_archive`
        # below are new-version names spliced in by the diff scrape;
        # they are not parameters of this legacy initializer.
        self.return_dict = return_dict
        if download:
            self.download()
        archive = os.path.basename(url)
        archive = os.path.join(root, archive)
        self._path = os.path.join(root, folder_in_archive)
        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')
        self.data, self.labels = torch.load(os.path.join(
            self.root, self.processed_folder, self.processed_file))
        if download:
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    download_url(url, root)
                extract_archive(archive)
        if not os.path.isdir(self._path):
            raise RuntimeError(
                "Dataset not found. Please use `download=True` to download it."
            )

    def __getitem__(self, index):
        # NOTE(review): the diff scrape spliced four lines of the new
        # __init__ (the walk_files walker) into the middle of this
        # docstring; they are preserved verbatim inside the string.
        """
        Args:
            index (int): Index
        walker = walk_files(
            self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True
        )
        self._walker = list(walker)
        Returns:
            Tuple[torch.Tensor, int]: The output tuple (image, target) where target
            is index of the target class.
        """
        audio, target = self.data[index], self.labels[index]

    def __getitem__(self, n):
        # NOTE(review): second __getitem__ definition (new version);
        # shadows the legacy one above. Lines referencing `audio` /
        # `target` are legacy leftovers from the interleave, and the
        # `return audio, target` makes the return_dict branch
        # unreachable as placed.
        fileid = self._walker[n]
        item = load_yesno_item(fileid, self._path, self._ext_audio)
        waveform = item["waveform"]
        if self.transform is not None:
            audio = self.transform(audio)
            waveform = self.transform(waveform)
        item["waveform"] = waveform
        label = item["label"]
        if self.target_transform is not None:
            target = self.target_transform(target)
            label = self.target_transform(label)
        item["label"] = label
        return audio, target
        if self.return_dict:
            return item

    def __len__(self):
        return len(self.data)

    def _check_exists(self):
        # True once the processed .pt file has been written.
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.processed_file))
    def download(self):
        """Download the yesno data if it doesn't exist in processed_folder already."""
        from six.moves import urllib
        import tarfile

        if self._check_exists():
            return

        raw_abs_dir = os.path.join(self.root, self.raw_folder)
        processed_abs_dir = os.path.join(self.root, self.processed_folder)
        dset_abs_path = os.path.join(
            self.root, self.raw_folder, self.dset_path)

        # download files
        try:
            os.makedirs(os.path.join(self.root, self.raw_folder))
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise

        url = self.url
        print('Downloading ' + url)
        filename = url.rpartition('/')[2]
        file_path = os.path.join(self.root, self.raw_folder, filename)
        if not os.path.isfile(file_path):
            urllib.request.urlretrieve(url, file_path)
        else:
            print("Tar file already downloaded")
        if not os.path.exists(dset_abs_path):
            with tarfile.open(file_path) as zip_f:
                zip_f.extractall(raw_abs_dir)
        else:
            print("Tar file already extracted")
        if not self.dev_mode:
            os.unlink(file_path)

        # process and save as torch files
        print('Processing...')
        shutil.copyfile(
            os.path.join(dset_abs_path, "README"),
            os.path.join(processed_abs_dir, "YESNO_README")
        )
        audios = [x for x in os.listdir(dset_abs_path) if ".wav" in x]
        print("Found {} audio files".format(len(audios)))
        tensors = []
        labels = []
        lengths = []
        for i, f in enumerate(audios):
            full_path = os.path.join(dset_abs_path, f)
            sig, sr = torchaudio.load(full_path)
            tensors.append(sig)
            lengths.append(sig.size(1))
            # Label = underscore-separated tokens of the file basename.
            labels.append(os.path.basename(f).split(".", 1)[0].split("_"))
        # sort sigs/labels: longest -> shortest
        tensors, labels = zip(*[(b, c) for (a, b, c) in sorted(
            zip(lengths, tensors, labels), key=lambda x: x[0], reverse=True)])
        self.max_len = tensors[0].size(1)
        torch.save(
            (tensors, labels),
            os.path.join(
                self.root,
                self.processed_folder,
                self.processed_file
            )
        )
        if not self.dev_mode:
            shutil.rmtree(raw_abs_dir, ignore_errors=True)
        # NOTE(review): stray tail of the new-version __getitem__,
        # spliced here by the diff scrape.
        return item["waveform"], item["label"]
        print('Done!')

    def __len__(self):
        # NOTE(review): new-version __len__; `_walker` is only set by
        # the interleaved new initializer.
        return len(self._walker)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment