Commit 4a65b050 authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Add Speech Commands metadata function (#2687)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2687

Reviewed By: mthrok

Differential Revision: D39647596

Pulled By: carolineechen

fbshipit-source-id: 8ff874fc1e828130f6754e83ce1f702ca13dfac0
parent ad15bc71
...@@ -2,11 +2,10 @@ import os ...@@ -2,11 +2,10 @@ import os
from pathlib import Path from pathlib import Path
from typing import Tuple, Union from typing import Tuple, Union
import torchaudio
from torch import Tensor from torch import Tensor
from torch.hub import download_url_to_file from torch.hub import download_url_to_file
from torch.utils.data import Dataset from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive from torchaudio.datasets.utils import _load_waveform, extract_archive
URL = "train-clean-100" URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriSpeech" FOLDER_IN_ARCHIVE = "LibriSpeech"
...@@ -133,13 +132,6 @@ class LIBRISPEECH(Dataset): ...@@ -133,13 +132,6 @@ class LIBRISPEECH(Dataset):
fileid = self._walker[n] fileid = self._walker[n]
return _get_librispeech_metadata(fileid, self._archive, self._url, self._ext_audio, self._ext_txt) return _get_librispeech_metadata(fileid, self._archive, self._url, self._ext_audio, self._ext_txt)
def _load_waveform(self, path: str):
    """Load the audio file at *path* (relative to the dataset archive).

    Args:
        path (str): Audio file path relative to ``self._archive``.

    Returns:
        Tensor: The decoded waveform.

    Raises:
        ValueError: If the file is not sampled at ``SAMPLE_RATE`` (16 kHz).
    """
    waveform, sample_rate = torchaudio.load(os.path.join(self._archive, path))
    if sample_rate == SAMPLE_RATE:
        return waveform
    raise ValueError(f"sample rate should be 16000 (16kHz), but got {sample_rate}.")
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]: def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
"""Load the n-th sample from the dataset. """Load the n-th sample from the dataset.
...@@ -151,7 +143,7 @@ class LIBRISPEECH(Dataset): ...@@ -151,7 +143,7 @@ class LIBRISPEECH(Dataset):
``(waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)`` ``(waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)``
""" """
metadata = self.get_metadata(n) metadata = self.get_metadata(n)
waveform = self._load_waveform(metadata[0]) waveform = _load_waveform(self._archive, metadata[0], metadata[1])
return (waveform,) + metadata[1:] return (waveform,) + metadata[1:]
def __len__(self) -> int: def __len__(self) -> int:
......
...@@ -2,16 +2,16 @@ import os ...@@ -2,16 +2,16 @@ import os
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torchaudio
from torch import Tensor from torch import Tensor
from torch.hub import download_url_to_file from torch.hub import download_url_to_file
from torch.utils.data import Dataset from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive from torchaudio.datasets.utils import _load_waveform, extract_archive
FOLDER_IN_ARCHIVE = "SpeechCommands" FOLDER_IN_ARCHIVE = "SpeechCommands"
URL = "speech_commands_v0.02" URL = "speech_commands_v0.02"
HASH_DIVIDER = "_nohash_" HASH_DIVIDER = "_nohash_"
EXCEPT_FOLDER = "_background_noise_" EXCEPT_FOLDER = "_background_noise_"
SAMPLE_RATE = 16000
_CHECKSUMS = { _CHECKSUMS = {
"https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz": "743935421bb51cccdb6bdd152e04c5c70274e935c82119ad7faeec31780d811d", # noqa: E501 "https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz": "743935421bb51cccdb6bdd152e04c5c70274e935c82119ad7faeec31780d811d", # noqa: E501
"https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz": "af14739ee7dc311471de98f5f9d2c9191b18aedfe957f4a6ff791c709868ff58", # noqa: E501 "https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz": "af14739ee7dc311471de98f5f9d2c9191b18aedfe957f4a6ff791c709868ff58", # noqa: E501
...@@ -27,9 +27,10 @@ def _load_list(root, *filenames): ...@@ -27,9 +27,10 @@ def _load_list(root, *filenames):
return output return output
def load_speechcommands_item(filepath: str, path: str) -> Tuple[Tensor, int, str, str, int]: def _get_speechcommands_metadata(filepath: str, path: str) -> Tuple[str, int, str, str, int]:
relpath = os.path.relpath(filepath, path) relpath = os.path.relpath(filepath, path)
label, filename = os.path.split(relpath) reldir, filename = os.path.split(relpath)
_, label = os.path.split(reldir)
# Besides the officially supported split method for datasets defined by "validation_list.txt" # Besides the officially supported split method for datasets defined by "validation_list.txt"
# and "testing_list.txt" over "speech_commands_v0.0x.tar.gz" archives, an alternative split # and "testing_list.txt" over "speech_commands_v0.0x.tar.gz" archives, an alternative split
# method referred to in paragraph 2-3 of Section 7.1, references 13 and 14 of the original # method referred to in paragraph 2-3 of Section 7.1, references 13 and 14 of the original
...@@ -43,9 +44,7 @@ def load_speechcommands_item(filepath: str, path: str) -> Tuple[Tensor, int, str ...@@ -43,9 +44,7 @@ def load_speechcommands_item(filepath: str, path: str) -> Tuple[Tensor, int, str
speaker_id, utterance_number = speaker.split(HASH_DIVIDER) speaker_id, utterance_number = speaker.split(HASH_DIVIDER)
utterance_number = int(utterance_number) utterance_number = int(utterance_number)
# Load audio return relpath, SAMPLE_RATE, label, speaker_id, utterance_number
waveform, sample_rate = torchaudio.load(filepath)
return waveform, sample_rate, label, speaker_id, utterance_number
class SPEECHCOMMANDS(Dataset): class SPEECHCOMMANDS(Dataset):
...@@ -93,6 +92,7 @@ class SPEECHCOMMANDS(Dataset): ...@@ -93,6 +92,7 @@ class SPEECHCOMMANDS(Dataset):
# Get string representation of 'root' in case Path object is passed # Get string representation of 'root' in case Path object is passed
root = os.fspath(root) root = os.fspath(root)
self._archive = os.path.join(root, folder_in_archive)
basename = os.path.basename(url) basename = os.path.basename(url)
archive = os.path.join(root, basename) archive = os.path.join(root, basename)
...@@ -131,6 +131,20 @@ class SPEECHCOMMANDS(Dataset): ...@@ -131,6 +131,20 @@ class SPEECHCOMMANDS(Dataset):
walker = sorted(str(p) for p in Path(self._path).glob("*/*.wav")) walker = sorted(str(p) for p in Path(self._path).glob("*/*.wav"))
self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w] self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w]
def get_metadata(self, n: int) -> Tuple[str, int, str, str, int]:
    """Return the metadata of the n-th sample without loading audio.

    Same fields as :py:func:`__getitem__`, except the first element is the
    audio filepath rather than the decoded waveform.

    Args:
        n (int): The index of the sample to be loaded

    Returns:
        (str, int, str, str, int):
            ``(filepath, sample_rate, label, speaker_id, utterance_number)``
    """
    return _get_speechcommands_metadata(self._walker[n], self._archive)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int]: def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int]:
"""Load the n-th sample from the dataset. """Load the n-th sample from the dataset.
...@@ -141,8 +155,9 @@ class SPEECHCOMMANDS(Dataset): ...@@ -141,8 +155,9 @@ class SPEECHCOMMANDS(Dataset):
(Tensor, int, str, str, int): (Tensor, int, str, str, int):
``(waveform, sample_rate, label, speaker_id, utterance_number)`` ``(waveform, sample_rate, label, speaker_id, utterance_number)``
""" """
fileid = self._walker[n] metadata = self.get_metadata(n)
return load_speechcommands_item(fileid, self._path) waveform = _load_waveform(self._archive, metadata[0], metadata[1])
return (waveform,) + metadata[1:]
def __len__(self) -> int: def __len__(self) -> int:
return len(self._walker) return len(self._walker)
...@@ -8,6 +8,8 @@ import warnings ...@@ -8,6 +8,8 @@ import warnings
import zipfile import zipfile
from typing import Any, Iterable, List, Optional from typing import Any, Iterable, List, Optional
import torchaudio
from torch.utils.model_zoo import tqdm from torch.utils.model_zoo import tqdm
...@@ -189,3 +191,15 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo ...@@ -189,3 +191,15 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo
pass pass
raise NotImplementedError("We currently only support tar.gz, tgz, and zip achives.") raise NotImplementedError("We currently only support tar.gz, tgz, and zip achives.")
def _load_waveform(
    root: str,
    filename: str,
    exp_sample_rate: int,
):
    """Load an audio file from a dataset directory, validating its sample rate.

    Args:
        root (str): Directory the audio file lives under.
        filename (str): Audio file path relative to ``root``.
        exp_sample_rate (int): Sample rate the file is expected to have.

    Returns:
        Tensor: The decoded waveform.

    Raises:
        ValueError: If the file's actual sample rate differs from ``exp_sample_rate``.
    """
    waveform, sample_rate = torchaudio.load(os.path.join(root, filename))
    if sample_rate != exp_sample_rate:
        raise ValueError(f"sample rate should be {exp_sample_rate}, but got {sample_rate}")
    return waveform
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment