Unverified Commit 4d58bc46 authored by Tomás Osório's avatar Tomás Osório Committed by GitHub
Browse files

Adding Speech Command Dataset (#437)



* add speechcommand dataset and test

* prepend the full path to each result

* add missing param on docstring in walk_files

* add file to run tests on SpeechCommand Dataset

* reduce logic

* update test on SpeechCommands

* correct the indentation on docstring walk_files

* flake8 compliance

* change tuple type returned. move path split logic in load item.

* typo in name.

* redundant file path.

* filter background noise.
Co-authored-by: default avatarVincent QB <vincentqb@users.noreply.github.com>
parent 32bae85c
......@@ -3,6 +3,7 @@ import unittest
from torchaudio.datasets.commonvoice import COMMONVOICE
from torchaudio.datasets.librispeech import LIBRISPEECH
from torchaudio.datasets.speechcommands import SPEECHCOMMANDS
from torchaudio.datasets.utils import diskcache_iterator, bg_iterator
from torchaudio.datasets.vctk import VCTK
from torchaudio.datasets.yesno import YESNO
......@@ -52,6 +53,10 @@ class TestDatasets(unittest.TestCase):
data = LJSPEECH(self.path)
data[0]
def test_speechcommands(self):
data = SPEECHCOMMANDS(self.path)
data[0]
if __name__ == "__main__":
unittest.main()
from .commonvoice import COMMONVOICE
from .librispeech import LIBRISPEECH
from .speechcommands import SPEECHCOMMANDS
from .utils import bg_iterator, diskcache_iterator
from .vctk import VCTK
from .yesno import YESNO
......@@ -8,6 +9,7 @@ from .ljspeech import LJSPEECH
__all__ = (
"COMMONVOICE",
"LIBRISPEECH",
"SPEECHCOMMANDS",
"VCTK",
"YESNO",
"LJSPEECH",
......
import os
import torchaudio
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
download_url,
extract_archive,
walk_files
)
FOLDER_IN_ARCHIVE = "SpeechCommands"
URL = "speech_commands_v0.02"
HASH_DIVIDER = "_nohash_"
EXCEPT_FOLDER = "_background_noise_"
def load_speechcommands_item(filepath, path):
relpath = os.path.relpath(filepath, path)
label, filename = os.path.split(relpath)
speaker, _ = os.path.splitext(filename)
speaker_id, utterance_number = speaker.split(HASH_DIVIDER)
utterance_number = int(utterance_number)
# Load audio
waveform, sample_rate = torchaudio.load(filepath)
return waveform, sample_rate, label, speaker_id, utterance_number
class SPEECHCOMMANDS(Dataset):
"""
Create a Dataset for Speech Commands. Each item is a tuple of the form:
waveform, sample_rate, label, speaker_id, utterance_number
"""
def __init__(
self,
root,
url=URL,
folder_in_archive=FOLDER_IN_ARCHIVE,
download=False
):
if url in [
"speech_commands_v0.01",
"speech_commands_v0.02",
]:
base_url = "https://storage.googleapis.com/download.tensorflow.org/data/"
ext_archive = ".tar.gz"
url = os.path.join(base_url, url + ext_archive)
basename = os.path.basename(url)
archive = os.path.join(root, basename)
basename = basename.rsplit(".", 2)[0]
folder_in_archive = os.path.join(folder_in_archive, basename)
self._path = os.path.join(root, folder_in_archive)
if download:
if not os.path.isdir(self._path):
if not os.path.isfile(archive):
download_url(url, root)
extract_archive(archive, self._path)
walker = walk_files(self._path, suffix=".wav", prefix=True)
walker = filter(lambda w: HASH_DIVIDER in w and EXCEPT_FOLDER not in w, walker)
self._walker = list(walker)
def __getitem__(self, n):
fileid = self._walker[n]
return load_speechcommands_item(fileid, self._path)
def __len__(self):
return len(self._walker)
......@@ -257,21 +257,23 @@ def walk_files(root, suffix, prefix=False, remove_suffix=False):
root (str): Path to directory whose folders need to be listed
suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
It uses the Python "str.endswith" method and is passed directly
prefix (bool, optional): If true, prepends the path to each result, otherwise
prefix (bool, optional): If true, prepends the full path to each result, otherwise
only returns the name of the files found
remove_suffix (bool, optional): If true, removes the suffix to each result defined in suffix,
otherwise will return the result as found.
"""
root = os.path.expanduser(root)
for _, _, fn in os.walk(root):
for f in fn:
for dirpath, _, files in os.walk(root):
for f in files:
if f.endswith(suffix):
if remove_suffix:
f = f[: -len(suffix)]
if prefix:
f = os.path.join(root, f)
f = os.path.join(dirpath, f)
yield f
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment