Unverified Commit 4d58bc46 authored by Tomás Osório's avatar Tomás Osório Committed by GitHub
Browse files

Adding Speech Command Dataset (#437)



* add speechcommand dataset and test

* prepend the full path to each result

* add missing param on docstring in walk_files

* add file to run tests on SpeechCommand Dataset

* reduce logic

* update test on SpeechCommands

* correct the indentation on docstring walk_files

* flake8 compliance

* change tuple type returned. move path split logic in load item.

* typo in name.

* redundant file path.

* filter background noise.
Co-authored-by: default avatarVincent QB <vincentqb@users.noreply.github.com>
parent 32bae85c
...@@ -3,6 +3,7 @@ import unittest ...@@ -3,6 +3,7 @@ import unittest
from torchaudio.datasets.commonvoice import COMMONVOICE from torchaudio.datasets.commonvoice import COMMONVOICE
from torchaudio.datasets.librispeech import LIBRISPEECH from torchaudio.datasets.librispeech import LIBRISPEECH
from torchaudio.datasets.speechcommands import SPEECHCOMMANDS
from torchaudio.datasets.utils import diskcache_iterator, bg_iterator from torchaudio.datasets.utils import diskcache_iterator, bg_iterator
from torchaudio.datasets.vctk import VCTK from torchaudio.datasets.vctk import VCTK
from torchaudio.datasets.yesno import YESNO from torchaudio.datasets.yesno import YESNO
...@@ -52,6 +53,10 @@ class TestDatasets(unittest.TestCase): ...@@ -52,6 +53,10 @@ class TestDatasets(unittest.TestCase):
data = LJSPEECH(self.path) data = LJSPEECH(self.path)
data[0] data[0]
def test_speechcommands(self):
data = SPEECHCOMMANDS(self.path)
data[0]
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
from .commonvoice import COMMONVOICE from .commonvoice import COMMONVOICE
from .librispeech import LIBRISPEECH from .librispeech import LIBRISPEECH
from .speechcommands import SPEECHCOMMANDS
from .utils import bg_iterator, diskcache_iterator from .utils import bg_iterator, diskcache_iterator
from .vctk import VCTK from .vctk import VCTK
from .yesno import YESNO from .yesno import YESNO
...@@ -8,6 +9,7 @@ from .ljspeech import LJSPEECH ...@@ -8,6 +9,7 @@ from .ljspeech import LJSPEECH
__all__ = ( __all__ = (
"COMMONVOICE", "COMMONVOICE",
"LIBRISPEECH", "LIBRISPEECH",
"SPEECHCOMMANDS",
"VCTK", "VCTK",
"YESNO", "YESNO",
"LJSPEECH", "LJSPEECH",
......
import os
import torchaudio
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
download_url,
extract_archive,
walk_files
)
FOLDER_IN_ARCHIVE = "SpeechCommands"
URL = "speech_commands_v0.02"
HASH_DIVIDER = "_nohash_"
EXCEPT_FOLDER = "_background_noise_"
def load_speechcommands_item(filepath, path):
relpath = os.path.relpath(filepath, path)
label, filename = os.path.split(relpath)
speaker, _ = os.path.splitext(filename)
speaker_id, utterance_number = speaker.split(HASH_DIVIDER)
utterance_number = int(utterance_number)
# Load audio
waveform, sample_rate = torchaudio.load(filepath)
return waveform, sample_rate, label, speaker_id, utterance_number
class SPEECHCOMMANDS(Dataset):
"""
Create a Dataset for Speech Commands. Each item is a tuple of the form:
waveform, sample_rate, label, speaker_id, utterance_number
"""
def __init__(
self,
root,
url=URL,
folder_in_archive=FOLDER_IN_ARCHIVE,
download=False
):
if url in [
"speech_commands_v0.01",
"speech_commands_v0.02",
]:
base_url = "https://storage.googleapis.com/download.tensorflow.org/data/"
ext_archive = ".tar.gz"
url = os.path.join(base_url, url + ext_archive)
basename = os.path.basename(url)
archive = os.path.join(root, basename)
basename = basename.rsplit(".", 2)[0]
folder_in_archive = os.path.join(folder_in_archive, basename)
self._path = os.path.join(root, folder_in_archive)
if download:
if not os.path.isdir(self._path):
if not os.path.isfile(archive):
download_url(url, root)
extract_archive(archive, self._path)
walker = walk_files(self._path, suffix=".wav", prefix=True)
walker = filter(lambda w: HASH_DIVIDER in w and EXCEPT_FOLDER not in w, walker)
self._walker = list(walker)
def __getitem__(self, n):
fileid = self._walker[n]
return load_speechcommands_item(fileid, self._path)
def __len__(self):
return len(self._walker)
...@@ -257,21 +257,23 @@ def walk_files(root, suffix, prefix=False, remove_suffix=False): ...@@ -257,21 +257,23 @@ def walk_files(root, suffix, prefix=False, remove_suffix=False):
root (str): Path to directory whose folders need to be listed root (str): Path to directory whose folders need to be listed
suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
It uses the Python "str.endswith" method and is passed directly It uses the Python "str.endswith" method and is passed directly
prefix (bool, optional): If true, prepends the path to each result, otherwise prefix (bool, optional): If true, prepends the full path to each result, otherwise
only returns the name of the files found only returns the name of the files found
remove_suffix (bool, optional): If true, removes the suffix to each result defined in suffix,
otherwise will return the result as found.
""" """
root = os.path.expanduser(root) root = os.path.expanduser(root)
for _, _, fn in os.walk(root): for dirpath, _, files in os.walk(root):
for f in fn: for f in files:
if f.endswith(suffix): if f.endswith(suffix):
if remove_suffix: if remove_suffix:
f = f[: -len(suffix)] f = f[: -len(suffix)]
if prefix: if prefix:
f = os.path.join(root, f) f = os.path.join(dirpath, f)
yield f yield f
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment