Commit 4a65b050 authored by Caroline Chen, committed by Facebook GitHub Bot
Browse files

Add Speech Commands metadata function (#2687)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2687

Reviewed By: mthrok

Differential Revision: D39647596

Pulled By: carolineechen

fbshipit-source-id: 8ff874fc1e828130f6754e83ce1f702ca13dfac0
parent ad15bc71
......@@ -2,11 +2,10 @@ import os
from pathlib import Path
from typing import Tuple, Union
import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive
from torchaudio.datasets.utils import _load_waveform, extract_archive
URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriSpeech"
......@@ -133,13 +132,6 @@ class LIBRISPEECH(Dataset):
fileid = self._walker[n]
return _get_librispeech_metadata(fileid, self._archive, self._url, self._ext_audio, self._ext_txt)
def _load_waveform(self, path: str):
    """Load the audio file at ``path`` (relative to the dataset archive directory).

    Args:
        path (str): Path of the audio file, joined onto ``self._archive``.

    Returns:
        The waveform returned by ``torchaudio.load``.

    Raises:
        ValueError: If the loaded sample rate differs from ``SAMPLE_RATE``.
    """
    audio, loaded_rate = torchaudio.load(os.path.join(self._archive, path))
    if loaded_rate != SAMPLE_RATE:
        raise ValueError(f"sample rate should be 16000 (16kHz), but got {loaded_rate}.")
    return audio
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
"""Load the n-th sample from the dataset.
......@@ -151,7 +143,7 @@ class LIBRISPEECH(Dataset):
``(waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)``
"""
metadata = self.get_metadata(n)
waveform = self._load_waveform(metadata[0])
waveform = _load_waveform(self._archive, metadata[0], metadata[1])
return (waveform,) + metadata[1:]
def __len__(self) -> int:
......
......@@ -2,16 +2,16 @@ import os
from pathlib import Path
from typing import Optional, Tuple, Union
import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive
from torchaudio.datasets.utils import _load_waveform, extract_archive
FOLDER_IN_ARCHIVE = "SpeechCommands"
URL = "speech_commands_v0.02"
HASH_DIVIDER = "_nohash_"
EXCEPT_FOLDER = "_background_noise_"
SAMPLE_RATE = 16000
_CHECKSUMS = {
"https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz": "743935421bb51cccdb6bdd152e04c5c70274e935c82119ad7faeec31780d811d", # noqa: E501
"https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz": "af14739ee7dc311471de98f5f9d2c9191b18aedfe957f4a6ff791c709868ff58", # noqa: E501
......@@ -27,9 +27,10 @@ def _load_list(root, *filenames):
return output
def load_speechcommands_item(filepath: str, path: str) -> Tuple[Tensor, int, str, str, int]:
def _get_speechcommands_metadata(filepath: str, path: str) -> Tuple[str, int, str, str, int]:
relpath = os.path.relpath(filepath, path)
label, filename = os.path.split(relpath)
reldir, filename = os.path.split(relpath)
_, label = os.path.split(reldir)
# Besides the officially supported split method for datasets defined by "validation_list.txt"
# and "testing_list.txt" over "speech_commands_v0.0x.tar.gz" archives, an alternative split
# method referred to in paragraph 2-3 of Section 7.1, references 13 and 14 of the original
......@@ -43,9 +44,7 @@ def load_speechcommands_item(filepath: str, path: str) -> Tuple[Tensor, int, str
speaker_id, utterance_number = speaker.split(HASH_DIVIDER)
utterance_number = int(utterance_number)
# Load audio
waveform, sample_rate = torchaudio.load(filepath)
return waveform, sample_rate, label, speaker_id, utterance_number
return relpath, SAMPLE_RATE, label, speaker_id, utterance_number
class SPEECHCOMMANDS(Dataset):
......@@ -93,6 +92,7 @@ class SPEECHCOMMANDS(Dataset):
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
self._archive = os.path.join(root, folder_in_archive)
basename = os.path.basename(url)
archive = os.path.join(root, basename)
......@@ -131,6 +131,20 @@ class SPEECHCOMMANDS(Dataset):
walker = sorted(str(p) for p in Path(self._path).glob("*/*.wav"))
self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w]
def get_metadata(self, n: int) -> Tuple[str, int, str, str, int]:
    """Get metadata for the n-th sample from the dataset.

    The first element is the file path rather than the waveform; the
    remaining fields are the same as those returned by :py:func:`__getitem__`.

    Args:
        n (int): The index of the sample

    Returns:
        (str, int, str, str, int):
        ``(filepath, sample_rate, label, speaker_id, utterance_number)``
    """
    sample_path = self._walker[n]
    return _get_speechcommands_metadata(sample_path, self._archive)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int]:
"""Load the n-th sample from the dataset.
......@@ -141,8 +155,9 @@ class SPEECHCOMMANDS(Dataset):
(Tensor, int, str, str, int):
``(waveform, sample_rate, label, speaker_id, utterance_number)``
"""
fileid = self._walker[n]
return load_speechcommands_item(fileid, self._path)
metadata = self.get_metadata(n)
waveform = _load_waveform(self._archive, metadata[0], metadata[1])
return (waveform,) + metadata[1:]
def __len__(self) -> int:
return len(self._walker)
......@@ -8,6 +8,8 @@ import warnings
import zipfile
from typing import Any, Iterable, List, Optional
import torchaudio
from torch.utils.model_zoo import tqdm
......@@ -189,3 +191,15 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo
pass
raise NotImplementedError("We currently only support tar.gz, tgz, and zip achives.")
def _load_waveform(
    root: str,
    filename: str,
    exp_sample_rate: int,
):
    """Load an audio file and check that it has the expected sample rate.

    Args:
        root (str): Directory that ``filename`` is joined onto.
        filename (str): Path of the audio file, joined onto ``root``.
        exp_sample_rate (int): Sample rate the loaded audio is required to have.

    Returns:
        The waveform returned by ``torchaudio.load``.

    Raises:
        ValueError: If the sample rate reported by ``torchaudio.load`` differs
            from ``exp_sample_rate``.
    """
    audio, actual_rate = torchaudio.load(os.path.join(root, filename))
    if exp_sample_rate != actual_rate:
        raise ValueError(f"sample rate should be {exp_sample_rate}, but got {actual_rate}")
    return audio
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment