"macapp/src/app.css" did not exist on "3cd59936a2f92ed217a032f177faad9fb500c4c3"
Commit ad2b61d7 authored by Caroline Chen, committed by Facebook GitHub Bot
Browse files

Add metadata mode for various datasets (#2697)

Summary:
Add metadata mode for the following SUPERB benchmark datasets
- QUESST14
- Fluent Speech Commands
- VoxCeleb1

Follow-ups:
- Add metadata mode for LibriMix -- waiting for unit tests to merge
- Add IEMOCAP + SNIPS datasets

Pull Request resolved: https://github.com/pytorch/audio/pull/2697

Reviewed By: mthrok

Differential Revision: D39666809

Pulled By: carolineechen

fbshipit-source-id: 3a8f07627acceed70f960f47e694efad75b108c2
parent 0b3ddec6
import csv
import os
from pathlib import Path
from typing import Union
from typing import Tuple, Union
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _load_waveform
SAMPLE_RATE = 16000
class FluentSpeechCommands(Dataset):
......@@ -34,27 +37,39 @@ class FluentSpeechCommands(Dataset):
self.header = data[0]
self.data = data[1:]
def __len__(self) -> int:
    """Return the number of samples, i.e. the length of ``self.data``."""
    return len(self.data)
def __getitem__(self, n: int):
"""Load the n-th sample from the dataset.
def get_metadata(self, n: int) -> Tuple[str, int, str, str, str, str, str, str]:
    """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
    but otherwise returns the same fields as :py:func:`__getitem__`.

    Args:
        n (int): The index of the sample to be loaded

    Returns:
        (str, int, str, str, str, str, str, str):
        ``(filepath, sample_rate, file_name, speaker_id, transcription, action, object, location)``
    """
    sample = self.data[n]

    # Reduce the "path" column to its last "/"-separated component, then keep
    # everything before the first dot.
    file_name = sample[self.header.index("path")].split("/")[-1]
    file_name = file_name.split(".")[0]

    # The remaining fields are read positionally from the third column onwards.
    speaker_id, transcription, action, obj, location = sample[2:]

    # speaker_id is used verbatim as a path component, so it is a string (not an int).
    # The returned path has no root prefix; __getitem__ combines it with ``self._path``.
    file_path = os.path.join("wavs", "speakers", speaker_id, f"{file_name}.wav")

    return file_path, SAMPLE_RATE, file_name, speaker_id, transcription, action, obj, location
def __len__(self) -> int:
    """Return the number of samples, i.e. the length of ``self.data``."""
    return len(self.data)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str, str, str]:
    """Load the n-th sample from the dataset.

    Args:
        n (int): The index of the sample to be loaded

    Returns:
        (Tensor, int, str, str, str, str, str, str):
        ``(waveform, sample_rate, file_name, speaker_id, transcription, action, object, location)``
    """
    metadata = self.get_metadata(n)
    # metadata[0] is the file path without the dataset root and metadata[1] is the
    # sample rate; both are handed to the shared loader together with the root.
    waveform = _load_waveform(self._path, metadata[0], metadata[1])
    # Swap the leading file path for the decoded waveform; all other fields are unchanged.
    return (waveform,) + metadata[1:]
......@@ -4,13 +4,13 @@ from pathlib import Path
from typing import Optional, Tuple, Union
import torch
import torchaudio
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive
from torchaudio.datasets.utils import _load_waveform, extract_archive
URL = "https://speech.fit.vutbr.cz/files/quesst14Database.tgz"
SAMPLE_RATE = 8000
_CHECKSUM = "4f869e06bc066bbe9c5dde31dbd3909a0870d70291110ebbb38878dcbc2fc5e4"
_LANGUAGES = [
"albanian",
......@@ -71,10 +71,20 @@ class QUESST14(Dataset):
elif subset == "eval":
self.data = filter_audio_paths(self._path, language, "language_key_eval.lst")
def _load_sample(self, n: int) -> Tuple[torch.Tensor, int, str]:
def get_metadata(self, n: int) -> Tuple[str, int, str]:
    """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
    but otherwise returns the same fields as :py:func:`__getitem__`.

    Args:
        n (int): The index of the sample to be loaded

    Returns:
        (str, int, str):
        ``(filepath, sample_rate, file_name)``
    """
    audio_path = self.data[n]
    # Report the path relative to the dataset root; no audio is decoded here.
    relpath = os.path.relpath(audio_path, self._path)
    # The file name is the final path component with its suffix removed.
    return relpath, SAMPLE_RATE, audio_path.with_suffix("").name
def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str]:
    """Load the n-th sample from the dataset.

    Args:
        n (int): The index of the sample to be loaded

    Returns:
        (Tensor, int, str): ``(waveform, sample_rate, file_name)``
    """
    metadata = self.get_metadata(n)
    # metadata[0] is the path relative to the dataset root, metadata[1] the sample rate.
    waveform = _load_waveform(self._path, metadata[0], metadata[1])
    # Swap the leading file path for the decoded waveform; the other fields are unchanged.
    return (waveform,) + metadata[1:]
def __len__(self) -> int:
    """Return the number of samples, i.e. the length of ``self.data``."""
    return len(self.data)
......
......@@ -2,13 +2,13 @@ import os
from pathlib import Path
from typing import List, Tuple, Union
import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive
from torchaudio.datasets.utils import _load_waveform, extract_archive
SAMPLE_RATE = 16000
_ARCHIVE_CONFIGS = {
"dev": {
"archive_name": "vox1_dev_wav.zip",
......@@ -111,6 +111,9 @@ class VoxCeleb1(Dataset):
)
_download_extract_wavs(root)
def get_metadata(self, n: int):
    """Placeholder on the base class: always raises ``NotImplementedError``.

    Args:
        n (int): The index of the sample (unused here)

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
def __getitem__(self, n: int):
    """Placeholder on the base class: always raises ``NotImplementedError``.

    Args:
        n (int): The index of the sample (unused here)

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
......@@ -145,6 +148,23 @@ class VoxCeleb1Identification(VoxCeleb1):
download_url_to_file(meta_url, meta_list_path)
self._flist = _get_flist(self._path, meta_list_path, subset)
def get_metadata(self, n: int) -> Tuple[str, int, int, str]:
    """Describe the n-th sample without decoding its audio.

    The result has the same layout as :py:func:`__getitem__`, except that the
    first element is the file path rather than the waveform.

    Args:
        n (int): The index of the sample

    Returns:
        (str, int, int, str):
        ``(filepath, sample_rate, speaker_id, file_id)``
    """
    relative_path = self._flist[n]
    utterance_id = _get_file_id(relative_path, self._ext_audio)
    # The speaker tag is the text before the first "-"; its first three
    # characters are skipped and the remainder is parsed as an integer.
    speaker_tag = utterance_id.partition("-")[0]
    return relative_path, SAMPLE_RATE, int(speaker_tag[3:]), utterance_id
def __getitem__(self, n: int) -> Tuple[Tensor, int, int, str]:
    """Load the n-th sample from the dataset.

    Args:
        n (int): The index of the sample

    Returns:
        (Tensor, int, int, str):
        ``(waveform, sample_rate, speaker_id, file_id)``
    """
    # All path and id parsing lives in get_metadata, so it is not repeated here.
    metadata = self.get_metadata(n)
    # metadata[0] is the file path from the file list, metadata[1] the sample rate.
    waveform = _load_waveform(self._path, metadata[0], metadata[1])
    # Swap the leading file path for the decoded waveform; the other fields are unchanged.
    return (waveform,) + metadata[1:]
def __len__(self) -> int:
    """Return the number of samples, i.e. the length of ``self._flist``."""
    return len(self._flist)
......@@ -190,6 +207,23 @@ class VoxCeleb1Verification(VoxCeleb1):
download_url_to_file(meta_url, meta_list_path)
self._flist = _get_paired_flist(self._path, meta_list_path)
def get_metadata(self, n: int) -> Tuple[str, str, int, int, str, str]:
    """Describe the n-th trial pair without decoding any audio.

    The result has the same layout as :py:func:`__getitem__`, except that the
    first two elements are file paths rather than waveforms.

    Args:
        n (int): The index of the sample

    Returns:
        (str, str, int, int, str, str):
        ``(filepath_spk1, filepath_spk2, sample_rate, label, file_id_spk1, file_id_spk2)``
    """
    # Each list entry holds the label followed by the two file paths.
    raw_label, path_first, path_second = self._flist[n]
    return (
        path_first,
        path_second,
        SAMPLE_RATE,
        int(raw_label),
        _get_file_id(path_first, self._ext_audio),
        _get_file_id(path_second, self._ext_audio),
    )
def __getitem__(self, n: int) -> Tuple[Tensor, Tensor, int, int, str, str]:
    """Load the n-th sample from the dataset.

    Args:
        n (int): The index of the sample

    Returns:
        (Tensor, Tensor, int, int, str, str):
        ``(waveform_spk1, waveform_spk2, sample_rate, label, file_id_spk1, file_id_spk2)``
    """
    # All label and id parsing lives in get_metadata, so it is not repeated here.
    metadata = self.get_metadata(n)
    # metadata[0] and metadata[1] are the two file paths; metadata[2] is the
    # sample rate passed to the loader for both files.
    # NOTE(review): the explicit equality check between the two loaded sample rates
    # is no longer done here; confirm that `_load_waveform` validates the rate it is given.
    waveform_spk1 = _load_waveform(self._path, metadata[0], metadata[2])
    waveform_spk2 = _load_waveform(self._path, metadata[1], metadata[2])
    # Swap the two leading file paths for the decoded waveforms.
    return (waveform_spk1, waveform_spk2) + metadata[2:]
def __len__(self) -> int:
    """Return the number of samples, i.e. the length of ``self._flist``."""
    return len(self._flist)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment