Unverified Commit 8920802d authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

new dataset format with librispeech and commonvoice (#303)

* new dataset format.
* add basic test.
* files for testing.
* serialization using torch.
* add diskcache.
* adding deprecation warnings.
* removing legacy.
* warning about transforms.
* detecting file format using reader.
parent b8203182
client_id path sentence up_votes down_votes age gender accent
00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 common_voice_tt_00000000.mp3 test. 1 0 thirties female
import os
import unittest
from torchaudio.datasets.commonvoice import COMMONVOICE
from torchaudio.datasets.librispeech import LIBRISPEECH
from torchaudio.datasets.utils import DiskCache
from torchaudio.datasets.vctk import VCTK
from torchaudio.datasets.yesno import YESNO
class TestDatasets(unittest.TestCase):
path = "assets"
def test_yesno(self):
data = YESNO(self.path, return_dict=True)
data[0]
def test_vctk(self):
data = VCTK(self.path, return_dict=True)
data[0]
def test_librispeech(self):
data = LIBRISPEECH(self.path, "dev-clean")
data[0]
def test_commonvoice(self):
path = os.path.join(self.path, "commonvoice")
data = COMMONVOICE(path, "train.tsv", "tatar")
data[0]
def test_commonvoice_diskcache(self):
path = os.path.join(self.path, "commonvoice")
data = COMMONVOICE(path, "train.tsv", "tatar")
data = DiskCache(data)
# Save
data[0]
# Load
data[0]
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import torch
import torchaudio
import unittest
import common_utils
import torchaudio.datasets.vctk as vctk
class TestVCTK(unittest.TestCase):
def setUp(self):
self.test_dirpath, self.test_dir = common_utils.create_temp_assets_dir()
def get_full_path(self, file):
return os.path.join(self.test_dirpath, 'assets', file)
def test_is_audio_file(self):
self.assertTrue(vctk.is_audio_file('foo.wav'))
self.assertTrue(vctk.is_audio_file('foo.WAV'))
self.assertFalse(vctk.is_audio_file('foo.bar'))
def test_make_manifest(self):
audios = vctk.make_manifest(self.test_dirpath)
files = ['kaldi_file.wav', 'kaldi_file_8000.wav',
'sinewave.wav', 'steam-train-whistle-daniel_simon.mp3',
'dtmf_30s_stereo.mp3', 'whitenoise_1min.mp3', 'whitenoise.mp3']
files = [self.get_full_path(file) for file in files]
files.sort()
audios.sort()
self.assertEqual(files, audios, msg='files %s did not match audios %s' % (files, audios))
def test_read_audio_downsample_false(self):
file = self.get_full_path('kaldi_file.wav')
s, sr = vctk.read_audio(file, downsample=False)
self.assertEqual(sr, 16000, msg='incorrect sample rate %d' % (sr))
self.assertEqual(s.shape, (1, 20), msg='incorrect shape %s' % (str(s.shape)))
def test_read_audio_downsample_true(self):
file = self.get_full_path('kaldi_file.wav')
s, sr = vctk.read_audio(file, downsample=True)
self.assertEqual(sr, 16000, msg='incorrect sample rate %d' % (sr))
self.assertEqual(s.shape, (1, 20), msg='incorrect shape %s' % (str(s.shape)))
def test_load_txts(self):
utterences = vctk.load_txts(self.test_dirpath)
expected_utterances = {'file2': 'word5 word6\n', 'file1': 'word1 word2\n'}
self.assertEqual(utterences, expected_utterances,
msg='%s did not match %s' % (utterences, expected_utterances))
def test_vctk(self):
# TODO somehow test download=True, the dataset is too big download ~10 GB for
# each test so need a way to mock it
self.assertRaises(RuntimeError, vctk.VCTK, self.test_dirpath, download=False)
# Allow running this test module directly with `python <file>`.
if __name__ == '__main__':
    unittest.main()
# Public API of torchaudio.datasets.
# NOTE: this span was corrupted by a two-column diff render (old and new
# import lists merged onto single lines); reconstructed as the new version.
from .commonvoice import COMMONVOICE
from .librispeech import LIBRISPEECH
from .vctk import VCTK
from .yesno import YESNO
from .utils import DiskCache

__all__ = ("COMMONVOICE", "LIBRISPEECH", "VCTK", "YESNO", "DiskCache")
import os
import torchaudio
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url, extract_archive, unicode_csv_reader
# Default TSV should be one of
# dev.tsv
# invalidated.tsv
# other.tsv
# test.tsv
# train.tsv
# validated.tsv
URL = "english"
TSV = "train.tsv"
def load_commonvoice_item(line, header, path, folder_audio):
    """Build one CommonVoice sample dict from a parsed TSV row.

    Returns the row's fields keyed by the TSV header, plus "waveform" and
    "sample_rate" loaded from the referenced audio clip.
    """
    # Column 1 of the TSV row holds the clip file name (the "path" column).
    audio_name = line[1]
    audio_path = os.path.join(path, folder_audio, audio_name)
    waveform, sample_rate = torchaudio.load(audio_path)

    item = dict(zip(header, line))
    item["waveform"] = waveform
    item["sample_rate"] = sample_rate
    return item
class COMMONVOICE(Dataset):
    """CommonVoice speech dataset.

    Reads the samples listed in a CommonVoice TSV file under ``root`` and
    loads each audio clip from the ``clips`` sub-folder.

    Args:
        root (str): Directory containing the TSV file and the audio folder.
        tsv (str): Name of the TSV file to read (default: "train.tsv").
        url (str): Either a language name (e.g. "tatar"), expanded to the
            official corpus URL, or a full download URL (default: "english").
        download (bool): If true, download and extract the archive when the
            dataset is not found locally (default: False).
    """

    _ext_txt = ".txt"
    _ext_audio = ".mp3"
    _folder_audio = "clips"

    def __init__(self, root, tsv=TSV, url=URL, download=False):

        languages = {
            "tatar": "tt",
            "english": "en",
            "german": "de",
            "french": "fr",
            "welsh": "cy",
            "breton": "br",
            "chuvash": "cv",
            "turkish": "tr",
            "kyrgyz": "ky",
            "irish": "ga-IE",
            "kabyle": "kab",
            "catalan": "ca",
            "taiwanese": "zh-TW",
            "slovenian": "sl",
            "italian": "it",
            "dutch": "nl",
            "hakha chin": "cnh",
            "esperanto": "eo",
            "estonian": "et",
            "persian": "fa",
            "basque": "eu",
            "spanish": "es",
            "chinese": "zh-CN",
            "mongolian": "mn",
            "sakha": "sah",
            "dhivehi": "dv",
            "kinyarwanda": "rw",
            "swedish": "sv-SE",
            "russian": "ru",
        }

        # BUG FIX: the original used `url is languages`, an identity comparison
        # against the dict object itself, which is always False — named
        # languages were never expanded into their download URL.
        if url in languages:
            ext_archive = ".tar.gz"
            language = languages[url]
            base_url = (
                "https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4"
                + ".s3.amazonaws.com/cv-corpus-3/"
            )
            url = base_url + language + ext_archive

        archive = os.path.basename(url)
        archive = os.path.join(root, archive)
        self._path = root

        if download:
            # NOTE(review): `root` typically already exists, so this guard may
            # skip the download entirely — confirm the intended check.
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    download_url(url, root)
                extract_archive(archive)

        self._tsv = os.path.join(root, tsv)

        # Renamed the file handle so it no longer shadows the `tsv` argument.
        with open(self._tsv, "r") as tsv_file:
            walker = unicode_csv_reader(tsv_file, delimiter="\t")
            # First row of the TSV is the column header.
            self._header = next(walker)
            self._walker = list(walker)

    def __getitem__(self, n):
        """Return sample ``n`` as a dict of TSV fields plus waveform/sample_rate."""
        line = self._walker[n]
        return load_commonvoice_item(line, self._header, self._path, self._folder_audio)

    def __len__(self):
        return len(self._walker)
import os
import torchaudio
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
download_url,
extract_archive,
unicode_csv_reader,
walk_files,
)
URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriSpeech"
def load_librispeech_item(fileid, path, ext_audio, ext_txt):
    """Load one LibriSpeech example given its utterance id.

    Args:
        fileid (str): Utterance id of the form "<speaker>-<chapter>-<utterance>".
        path (str): Root of the extracted subset.
        ext_audio (str): Audio file extension (e.g. ".flac").
        ext_txt (str): Transcript file extension (e.g. ".trans.txt").

    Returns:
        dict with speaker/chapter/utterance ids, the transcript, the waveform
        and its sample rate.

    Raises:
        ValueError: if no transcript line matches ``fileid``.
    """
    speaker, chapter, utterance = fileid.split("-")

    file_text = speaker + "-" + chapter + ext_txt
    file_text = os.path.join(path, speaker, chapter, file_text)

    file_audio = speaker + "-" + chapter + "-" + utterance + ext_audio
    file_audio = os.path.join(path, speaker, chapter, file_audio)

    # Load audio
    waveform, sample_rate = torchaudio.load(file_audio)

    # Load text: each transcript line is "<utterance id> <transcript>".
    for line in open(file_text):
        fileid_text, content = line.strip().split(" ", 1)
        # BUG FIX: the original compared `file_audio` (a full path with
        # extension) against the bare utterance id, which never matches and
        # made every lookup raise ValueError. Compare the utterance id.
        if fileid == fileid_text:
            break
    else:
        # Translation not found
        raise ValueError("Transcript not found for " + fileid)

    return {
        "speaker_id": speaker,
        "chapter_id": chapter,
        "utterance_id": utterance,
        "utterance": content,
        "waveform": waveform,
        "sample_rate": sample_rate,
    }
class LIBRISPEECH(Dataset):
    """LibriSpeech dataset.

    Walks the extracted LibriSpeech folder for audio files and loads each
    example (audio + transcript) lazily via ``load_librispeech_item``.

    Args:
        root (str): Directory where the dataset is found or downloaded to.
        url (str): Subset name (e.g. "dev-clean") expanded to the OpenSLR
            download URL, or a full URL (default: "train-clean-100").
        folder_in_archive (str): Top-level folder inside the archive
            (default: "LibriSpeech").
        download (bool): If true, download and extract the archive when the
            dataset is not found locally (default: False).
    """

    _ext_txt = ".trans.txt"
    _ext_audio = ".flac"

    def __init__(
        self, root, url=URL, folder_in_archive=FOLDER_IN_ARCHIVE, download=False
    ):
        if url in [
            "dev-clean",
            "test-clean",
            "test-other",
            "train-clean-100",
            "train-clean-360",
            "train-other-500",
        ]:
            ext_archive = ".tar.gz"
            base_url = "http://www.openslr.org/resources/12/"
            # BUG FIX: URLs must be built with plain concatenation —
            # os.path.join would insert backslashes on Windows.
            url = base_url + url + ext_archive

        basename = os.path.basename(url)
        archive = os.path.join(root, basename)

        # "dev-clean.tar.gz" -> "dev-clean": subset folder inside the archive.
        basename = basename.split(".")[0]
        folder_in_archive = os.path.join(folder_in_archive, basename)
        self._path = os.path.join(root, folder_in_archive)

        if download:
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    download_url(url, root)
                extract_archive(archive)

        walker = walk_files(
            self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True
        )
        self._walker = list(walker)

    def __getitem__(self, n):
        """Return the n-th example as a dict (see ``load_librispeech_item``)."""
        fileid = self._walker[n]
        return load_librispeech_item(fileid, self._path, self._ext_audio, self._ext_txt)

    def __len__(self):
        return len(self._walker)
import csv
import errno
import gzip
import hashlib
import logging
import os
import sys
import tarfile
import zipfile
import six
import torch
import torchaudio
from torch.utils.data import Dataset
from torch.utils.model_zoo import tqdm
def unicode_csv_reader(unicode_csv_data, **kwargs):
    r"""Since the standard csv library does not handle unicode in Python 2, we need a wrapper.
    Borrowed and slightly modified from the Python docs:
    https://docs.python.org/2/library/csv.html#csv-examples

    Arguments:
        unicode_csv_data: unicode csv data (see example below)

    Yields:
        Each parsed CSV row as a list of unicode cells.

    Examples:
        >>> from torchaudio.datasets.utils import unicode_csv_reader
        >>> import io
        >>> with io.open(data_path, encoding="utf8") as f:
        >>>     reader = unicode_csv_reader(f)
    """
    # Fix field larger than field limit error
    maxInt = sys.maxsize
    while True:
        # decrease the maxInt value by factor 10
        # as long as the OverflowError occurs.
        try:
            csv.field_size_limit(maxInt)
            break
        except OverflowError:
            maxInt = int(maxInt / 10)
    # NOTE(review): this final call re-applies the limit already set when the
    # loop broke — it appears redundant; confirm before removing.
    csv.field_size_limit(maxInt)

    if six.PY2:
        # csv.py doesn't do Unicode; encode temporarily as UTF-8:
        # NOTE(review): `utf_8_encoder` is expected to be defined elsewhere in
        # this module — not visible in this chunk; confirm it exists.
        csv_reader = csv.reader(utf_8_encoder(unicode_csv_data), **kwargs)
        for row in csv_reader:
            # decode UTF-8 back to Unicode, cell by cell:
            yield [cell.decode("utf-8") for cell in row]
    else:
        for line in csv.reader(unicode_csv_data, **kwargs):
            yield line
def gen_bar_updater():
    """Return a ``reporthook`` callback for ``urlretrieve`` driving a tqdm bar."""
    pbar = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        # The total size is only known once the first callback fires.
        if pbar.total is None and total_size:
            pbar.total = total_size
        pbar.update(count * block_size - pbar.n)

    return bar_update
def makedir_exist_ok(dirpath):
    """
    Python2 support for os.makedirs(.., exist_ok=True)
    """
    try:
        os.makedirs(dirpath)
    except OSError as e:
        # Only an already-existing directory is acceptable; re-raise the rest.
        if e.errno != errno.EEXIST:
            raise
def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    # NOTE(review): `md5` is accepted but never verified in this body —
    # confirm whether checksum validation was intended.
    from six.moves import urllib

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir_exist_ok(root)

    # downloads file
    if os.path.isfile(fpath):
        # Skip the download entirely when the target file already exists.
        print("Using downloaded file: " + fpath)
    else:
        try:
            print("Downloading " + url + " to " + fpath)
            urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
        except (urllib.error.URLError, IOError) as e:
            if url[:5] == "https":
                # Retry once over plain http when the https fetch fails.
                url = url.replace("https:", "http:")
                print(
                    "Failed download. Trying https -> http instead."
                    " Downloading " + url + " to " + fpath
                )
                urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
            else:
                raise e
def extract_archive(from_path, to_path=None, overwrite=False):
    """Extract archive.

    Arguments:
        from_path: the path of the archive.
        to_path: the root path of the extracted files (directory of from_path)
        overwrite: overwrite existing files (False)

    Returns:
        List of paths to extracted files even if not overwritten.

    Examples:
        >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
        >>> from_path = './validation.tar.gz'
        >>> to_path = './'
        >>> torchaudio.datasets.utils.download_url(url, from_path)
        >>> torchaudio.datasets.utils.extract_archive(from_path, to_path)
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    # Try tar first; mode "r" transparently handles gzip/bz2 compression.
    try:
        with tarfile.open(from_path, "r") as tar:
            logging.info("Opened tar file {}.".format(from_path))
            files = []
            for file_ in tar:
                file_path = os.path.join(to_path, file_.name)
                if file_.isfile():
                    files.append(file_path)
                    if os.path.exists(file_path):
                        logging.info("{} already extracted.".format(file_path))
                        if not overwrite:
                            continue
                # NOTE(review): tar.extract trusts member names; a crafted
                # archive could escape to_path (path traversal) — confirm
                # archives come only from trusted sources.
                tar.extract(file_, to_path)
            return files
    except tarfile.ReadError:
        # Not a tar archive; fall through to the zip handler.
        pass

    try:
        with zipfile.ZipFile(from_path, "r") as zfile:
            logging.info("Opened zip file {}.".format(from_path))
            files = zfile.namelist()
            for file_ in files:
                file_path = os.path.join(to_path, file_)
                if os.path.exists(file_path):
                    logging.info("{} already extracted.".format(file_path))
                    if not overwrite:
                        continue
                zfile.extract(file_, to_path)
            return files
    except zipfile.BadZipFile:
        pass

    # BUG FIX: corrected "achives" -> "archives" in the error message.
    raise NotImplementedError("We currently only support tar.gz, tgz, and zip archives.")
def walk_files(root, suffix, prefix=False, remove_suffix=False):
    """List recursively all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
        remove_suffix (bool, optional): If true, strips the matched suffix from
            each returned name
    """
    root = os.path.expanduser(root)

    for dirpath, _, filenames in os.walk(root):
        for f in filenames:
            if f.endswith(suffix):
                if remove_suffix:
                    f = f[: -len(suffix)]
                if prefix:
                    # BUG FIX: join against the directory that actually
                    # contains the file. The original joined `root`, which
                    # produced wrong paths for files in subdirectories.
                    f = os.path.join(dirpath, f)
                yield f
class DiskCache(Dataset):
    """
    Wrap a dataset so that, whenever a new item is returned, it is saved to disk.
    """

    def __init__(self, dataset, location=".cached"):
        self.dataset = dataset
        self.location = location
        # id(self) namespaces the cache files of this particular wrapper.
        self._id = id(self)
        # One slot per item; holds the on-disk path once the item was saved.
        self._cache = [None] * len(dataset)

    def __getitem__(self, n):
        cached_path = self._cache[n]
        if cached_path:
            # Cache hit: deserialize the previously saved item.
            return torch.load(cached_path)

        cached_path = os.path.join(self.location, str(self._id) + "-" + str(n))

        item = self.dataset[n]
        self._cache[n] = cached_path

        makedir_exist_ok(self.location)
        torch.save(item, cached_path)

        return item

    def __len__(self):
        return len(self.dataset)
from __future__ import absolute_import, division, print_function, unicode_literals
import torch.utils.data as data
import os import os
import os.path import warnings
import shutil
import errno
import torch
import torchaudio
AUDIO_EXTENSIONS = [
'.wav', '.mp3', '.flac', '.sph', '.ogg', '.opus',
'.WAV', '.MP3', '.FLAC', '.SPH', '.OGG', '.OPUS',
]
def is_audio_file(filename): import torchaudio
return any(filename.endswith(extension) for extension in AUDIO_EXTENSIONS) from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url, extract_archive, walk_files
URL = "http://homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"
FOLDER_IN_ARCHIVE = "VCTK-Corpus"
def make_manifest(dir):
audios = []
dir = os.path.expanduser(dir)
for target in sorted(os.listdir(dir)):
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue
for root, _, fnames in sorted(os.walk(d)): def load_vctk_item(
for fname in fnames: fileid, path, ext_audio, ext_txt, folder_audio, folder_txt, downsample=False
if is_audio_file(fname): ):
path = os.path.join(root, fname) speaker, utterance = fileid.split("_")
item = path
audios.append(item)
return audios
# Read text
file_txt = os.path.join(path, folder_txt, speaker, fileid + ext_txt)
with open(file_txt) as file_text:
content = file_text.readlines()[0]
def read_audio(fp, downsample=True): # Read wav
file_audio = os.path.join(path, folder_audio, speaker, fileid + ext_audio)
if downsample: if downsample:
# Legacy
E = torchaudio.sox_effects.SoxEffectsChain() E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(fp) E.set_input_file(file_audio)
E.append_effect_to_chain("gain", ["-h"]) E.append_effect_to_chain("gain", ["-h"])
E.append_effect_to_chain("channels", [1]) E.append_effect_to_chain("channels", [1])
E.append_effect_to_chain("rate", [16000]) E.append_effect_to_chain("rate", [16000])
E.append_effect_to_chain("gain", ["-rh"]) E.append_effect_to_chain("gain", ["-rh"])
E.append_effect_to_chain("dither", ["-s"]) E.append_effect_to_chain("dither", ["-s"])
sig, sr = E.sox_build_flow_effects() waveform, sample_rate = E.sox_build_flow_effects()
else: else:
sig, sr = torchaudio.load(fp) waveform, sample_rate = torchaudio.load(file_audio)
sig = sig.contiguous()
return sig, sr return {
"speaker_id": speaker,
"utterance_id": utterance,
def load_txts(dir): "utterance": content,
"""Create a dictionary with all the text of the audio transcriptions.""" "waveform": waveform,
utterences = dict() "sample_rate": sample_rate,
dir = os.path.expanduser(dir) }
for target in sorted(os.listdir(dir)):
d = os.path.join(dir, target)
if not os.path.isdir(d): class VCTK(Dataset):
continue
_folder_txt = "txt"
for root, _, fnames in sorted(os.walk(d)): _folder_audio = "wav48"
for fname in fnames: _ext_txt = ".txt"
if fname.endswith(".txt"): _ext_audio = ".wav"
with open(os.path.join(root, fname), "r") as f:
fname_no_ext = os.path.basename( def __init__(
fname).rsplit(".", 1)[0] self,
utterences[fname_no_ext] = f.readline() root,
return utterences url=URL,
folder_in_archive=FOLDER_IN_ARCHIVE,
download=False,
class VCTK(data.Dataset): downsample=False,
r"""`VCTK <http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html>`_ Dataset. transform=None,
`alternate url <http://datashare.is.ed.ac.uk/handle/10283/2651>`_ target_transform=None,
return_dict=False,
Args: ):
root (str): Root directory of dataset where ``processed/training.pt``
and ``processed/test.pt`` exist. if not return_dict:
downsample (bool, optional): Whether to downsample the signal (Default: ``True``) warnings.warn(
transform (Callable, optional): A function/transform that takes in an raw audio "In the next version, the item returned will be a dictionary. "
and returns a transformed version. E.g, ``transforms.Spectrogram``. (Default: ``None``) "Please use `return_dict=True` to enable this behavior now, "
target_transform (callable, optional): A function/transform that takes in the "and suppress this warning.",
target and transforms it. (Default: ``None``) DeprecationWarning,
download (bool, optional): If true, downloads the dataset from the internet and )
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. (Default: ``True``) if downsample:
dev_mode(bool, optional): If true, clean up is not performed on downloaded warnings.warn(
files. Useful to keep raw audio and transcriptions. (Default: ``False``) "In the next version, transforms will not be part of the dataset. "
""" "Please use `downsample=False` to enable this behavior now, ",
raw_folder = 'vctk/raw' "and suppress this warning.",
processed_folder = 'vctk/processed' DeprecationWarning,
url = 'http://homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz' )
dset_path = 'VCTK-Corpus'
if transform is not None or target_transform is not None:
def __init__(self, root, downsample=True, transform=None, target_transform=None, download=False, dev_mode=False): warnings.warn(
self.root = os.path.expanduser(root) "In the next version, transforms will not be part of the dataset. "
"Please remove the option `transform=True` and "
"`target_transform=True` to suppress this warning.",
DeprecationWarning,
)
self.downsample = downsample self.downsample = downsample
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self.dev_mode = dev_mode self.return_dict = return_dict
self.data = []
self.labels = [] archive = os.path.basename(url)
self.chunk_size = 1000 archive = os.path.join(root, archive)
self.num_samples = 0 self._path = os.path.join(root, folder_in_archive)
self.max_len = 0
self.cached_pt = 0
if download: if download:
self.download() if not os.path.isdir(self._path):
if not os.path.isfile(archive):
if not self._check_exists(): download_url(url, root)
raise RuntimeError('Dataset not found.' + extract_archive(archive)
' You can use download=True to download it')
self._read_info() if not os.path.isdir(self._path):
self.data, self.labels = torch.load(os.path.join( raise RuntimeError(
self.root, self.processed_folder, "vctk_{:04d}.pt".format(self.cached_pt))) "Dataset not found. Please use `download=True` to download it."
)
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
Tuple[torch.Tensor, int]: The output tuple (image, target) where target
is index of the target class.
"""
if self.cached_pt != index // self.chunk_size:
self.cached_pt = int(index // self.chunk_size)
self.data, self.labels = torch.load(os.path.join(
self.root, self.processed_folder, "vctk_{:04d}.pt".format(self.cached_pt)))
index = index % self.chunk_size
audio, target = self.data[index], self.labels[index]
walker = walk_files(
self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True
)
self._walker = list(walker)
def __getitem__(self, n):
fileid = self._walker[n]
item = load_vctk_item(
fileid,
self._path,
self._ext_audio,
self._ext_txt,
self._folder_audio,
self._folder_txt,
)
# Legacy
waveform = item["waveform"]
if self.transform is not None: if self.transform is not None:
audio = self.transform(audio) waveform = self.transform(waveform)
item["waveform"] = waveform
# Legacy
utterance = item["utterance"]
if self.target_transform is not None: if self.target_transform is not None:
target = self.target_transform(target) utterance = self.target_transform(utterance)
item["utterance"] = utterance
if self.return_dict:
return item
return audio, target # Legacy
return item["waveform"], item["utterance"]
def __len__(self): def __len__(self):
return self.num_samples return len(self._walker)
def _check_exists(self):
return os.path.exists(os.path.join(self.root, self.processed_folder, "vctk_info.txt"))
def _write_info(self, num_items):
info_path = os.path.join(
self.root, self.processed_folder, "vctk_info.txt")
with open(info_path, "w") as f:
f.write("num_samples,{}\n".format(num_items))
f.write("max_len,{}\n".format(self.max_len))
def _read_info(self):
info_path = os.path.join(
self.root, self.processed_folder, "vctk_info.txt")
with open(info_path, "r") as f:
self.num_samples = int(f.readline().split(",")[1])
self.max_len = int(f.readline().split(",")[1])
def download(self):
"""Download the VCTK data if it doesn't exist in processed_folder already."""
from six.moves import urllib
import tarfile
if self._check_exists():
return
raw_abs_dir = os.path.join(self.root, self.raw_folder)
processed_abs_dir = os.path.join(self.root, self.processed_folder)
dset_abs_path = os.path.join(
self.root, self.raw_folder, self.dset_path)
# download files
try:
os.makedirs(os.path.join(self.root, self.processed_folder))
os.makedirs(os.path.join(self.root, self.raw_folder))
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
url = self.url
print('Downloading ' + url)
filename = url.rpartition('/')[2]
file_path = os.path.join(self.root, self.raw_folder, filename)
if not os.path.isfile(file_path):
urllib.request.urlretrieve(url, file_path)
if not os.path.exists(dset_abs_path):
with tarfile.open(file_path) as zip_f:
zip_f.extractall(raw_abs_dir)
else:
print("Using existing raw folder")
if not self.dev_mode:
os.unlink(file_path)
# process and save as torch files
torchaudio.initialize_sox()
print('Processing...')
shutil.copyfile(
os.path.join(dset_abs_path, "COPYING"),
os.path.join(processed_abs_dir, "VCTK_COPYING")
)
audios = make_manifest(dset_abs_path)
utterences = load_txts(dset_abs_path)
self.max_len = 0
print("Found {} audio files and {} utterences".format(
len(audios), len(utterences)))
for n in range(len(audios) // self.chunk_size + 1):
tensors = []
labels = []
lengths = []
st_idx = n * self.chunk_size
end_idx = st_idx + self.chunk_size
for i, f in enumerate(audios[st_idx:end_idx]):
txt_dir = os.path.dirname(f).replace("wav48", "txt")
if os.path.exists(txt_dir):
f_rel_no_ext = os.path.basename(f).rsplit(".", 1)[0]
sig = read_audio(f, downsample=self.downsample)[0]
tensors.append(sig)
lengths.append(sig.size(1))
labels.append(utterences[f_rel_no_ext])
self.max_len = sig.size(1) if sig.size(
1) > self.max_len else self.max_len
# sort sigs/labels: longest -> shortest
tensors, labels = zip(*[(b, c) for (a, b, c) in sorted(
zip(lengths, tensors, labels), key=lambda x: x[0], reverse=True)])
data = (tensors, labels)
torch.save(
data,
os.path.join(
self.root,
self.processed_folder,
"vctk_{:04d}.pt".format(n)
)
)
self._write_info((n * self.chunk_size) + i + 1)
if not self.dev_mode:
shutil.rmtree(raw_abs_dir, ignore_errors=True)
torchaudio.shutdown_sox()
print('Done!')
from __future__ import absolute_import, division, print_function, unicode_literals
import torch.utils.data as data
import os import os
import os.path import warnings
import shutil
import errno
import torch
import torchaudio import torchaudio
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url, extract_archive, walk_files
URL = "http://www.openslr.org/resources/1/waves_yesno.tar.gz"
FOLDER_IN_ARCHIVE = "waves_yesno"
def load_yesno_item(fileid, path, ext_audio):
    """Load one YesNo example: labels parsed from the file name plus its audio."""
    # Read label: file names encode the label sequence as underscore-separated
    # tokens (e.g. "0_1_0_..." -> ["0", "1", "0", ...]).
    labels = fileid.split("_")

    # Read wav
    file_audio = os.path.join(path, fileid + ext_audio)
    waveform, sample_rate = torchaudio.load(file_audio)

    return {"label": labels, "waveform": waveform, "sample_rate": sample_rate}
class YESNO(Dataset):
_ext_audio = ".wav"
def __init__(
self,
root,
url=URL,
folder_in_archive=FOLDER_IN_ARCHIVE,
download=False,
transform=None,
target_transform=None,
return_dict=False,
):
if not return_dict:
warnings.warn(
"In the next version, the item returned will be a dictionary. "
"Please use `return_dict=True` to enable this behavior now, "
"and suppress this warning.",
DeprecationWarning,
)
if transform is not None or target_transform is not None:
warnings.warn(
"In the next version, transforms will not be part of the dataset. "
"Please remove the option `transform=True` and "
"`target_transform=True` to suppress this warning.",
DeprecationWarning,
)
class YESNO(data.Dataset):
r"""`YesNo Hebrew <http://www.openslr.org/1/>`_ Dataset.
Args:
root (str): Root directory of dataset where ``processed/training.pt``
and ``processed/test.pt`` exist.
transform (Callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.Spectrogram``. (
Default: ``None``)
target_transform (Callable, optional): A function/transform that takes in the
target and transforms it. (Default: ``None``)
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. (Default: ``False``)
dev_mode(bool, optional): If true, clean up is not performed on downloaded
files. Useful to keep raw audio and transcriptions. (Default: ``False``)
"""
raw_folder = 'yesno/raw'
processed_folder = 'yesno/processed'
url = 'http://www.openslr.org/resources/1/waves_yesno.tar.gz'
dset_path = 'waves_yesno'
processed_file = 'yesno.pt'
def __init__(self, root, transform=None, target_transform=None, download=False, dev_mode=False):
self.root = os.path.expanduser(root)
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self.dev_mode = dev_mode self.return_dict = return_dict
self.data = []
self.labels = []
self.num_samples = 0
self.max_len = 0
if download: archive = os.path.basename(url)
self.download() archive = os.path.join(root, archive)
self._path = os.path.join(root, folder_in_archive)
if not self._check_exists(): if download:
raise RuntimeError('Dataset not found.' + if not os.path.isdir(self._path):
' You can use download=True to download it') if not os.path.isfile(archive):
self.data, self.labels = torch.load(os.path.join( download_url(url, root)
self.root, self.processed_folder, self.processed_file)) extract_archive(archive)
if not os.path.isdir(self._path):
raise RuntimeError(
"Dataset not found. Please use `download=True` to download it."
)
def __getitem__(self, index): walker = walk_files(
""" self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True
Args: )
index (int): Index self._walker = list(walker)
Returns: def __getitem__(self, n):
Tuple[torch.Tensor, int]: The output tuple (image, target) where target fileid = self._walker[n]
is index of the target class. item = load_yesno_item(fileid, self._path, self._ext_audio)
"""
audio, target = self.data[index], self.labels[index]
waveform = item["waveform"]
if self.transform is not None: if self.transform is not None:
audio = self.transform(audio) waveform = self.transform(waveform)
item["waveform"] = waveform
label = item["label"]
if self.target_transform is not None: if self.target_transform is not None:
target = self.target_transform(target) label = self.target_transform(label)
item["label"] = label
return audio, target if self.return_dict:
return item
def __len__(self): return item["waveform"], item["label"]
return len(self.data)
def _check_exists(self):
return os.path.exists(os.path.join(self.root, self.processed_folder, self.processed_file))
def download(self):
"""Download the yesno data if it doesn't exist in processed_folder already."""
from six.moves import urllib
import tarfile
if self._check_exists():
return
raw_abs_dir = os.path.join(self.root, self.raw_folder)
processed_abs_dir = os.path.join(self.root, self.processed_folder)
dset_abs_path = os.path.join(
self.root, self.raw_folder, self.dset_path)
# download files
try:
os.makedirs(os.path.join(self.root, self.raw_folder))
os.makedirs(os.path.join(self.root, self.processed_folder))
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
url = self.url
print('Downloading ' + url)
filename = url.rpartition('/')[2]
file_path = os.path.join(self.root, self.raw_folder, filename)
if not os.path.isfile(file_path):
urllib.request.urlretrieve(url, file_path)
else:
print("Tar file already downloaded")
if not os.path.exists(dset_abs_path):
with tarfile.open(file_path) as zip_f:
zip_f.extractall(raw_abs_dir)
else:
print("Tar file already extracted")
if not self.dev_mode:
os.unlink(file_path)
# process and save as torch files
print('Processing...')
shutil.copyfile(
os.path.join(dset_abs_path, "README"),
os.path.join(processed_abs_dir, "YESNO_README")
)
audios = [x for x in os.listdir(dset_abs_path) if ".wav" in x]
print("Found {} audio files".format(len(audios)))
tensors = []
labels = []
lengths = []
for i, f in enumerate(audios):
full_path = os.path.join(dset_abs_path, f)
sig, sr = torchaudio.load(full_path)
tensors.append(sig)
lengths.append(sig.size(1))
labels.append(os.path.basename(f).split(".", 1)[0].split("_"))
# sort sigs/labels: longest -> shortest
tensors, labels = zip(*[(b, c) for (a, b, c) in sorted(
zip(lengths, tensors, labels), key=lambda x: x[0], reverse=True)])
self.max_len = tensors[0].size(1)
torch.save(
(tensors, labels),
os.path.join(
self.root,
self.processed_folder,
self.processed_file
)
)
if not self.dev_mode:
shutil.rmtree(raw_abs_dir, ignore_errors=True)
print('Done!') def __len__(self):
return len(self._walker)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment