Commit ad2b61d7, authored by Caroline Chen and committed by Facebook GitHub Bot.

Add metadata mode for various datasets (#2697)

Summary:
Add metadata mode for the following SUPERB benchmark datasets
- QUESST14
- Fluent Speech Commands
- VoxCeleb1

follow ups:
- Add metadata mode for LibriMix -- waiting for unit tests to merge
- Add IEMOCAP + SNIPS datasets

Pull Request resolved: https://github.com/pytorch/audio/pull/2697

Reviewed By: mthrok

Differential Revision: D39666809

Pulled By: carolineechen

fbshipit-source-id: 3a8f07627acceed70f960f47e694efad75b108c2
parent 0b3ddec6
# Module header for the Fluent Speech Commands dataset (post-commit version,
# reconstructed from the collapsed two-column diff).
import csv
import os
from pathlib import Path
from typing import Tuple, Union

from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _load_waveform

# Fluent Speech Commands audio is distributed at a fixed 16 kHz sample rate,
# so the rate is a constant rather than being probed per file.
SAMPLE_RATE = 16000
class FluentSpeechCommands(Dataset): class FluentSpeechCommands(Dataset):
...@@ -34,27 +37,39 @@ class FluentSpeechCommands(Dataset): ...@@ -34,27 +37,39 @@ class FluentSpeechCommands(Dataset):
self.header = data[0] self.header = data[0]
self.data = data[1:] self.data = data[1:]
def get_metadata(self, n: int) -> Tuple[str, int, str, int, str, str, str, str]:
    """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
    but otherwise returns the same fields as :py:func:`__getitem__`.

    Args:
        n (int): The index of the sample to be loaded

    Returns:
        (str, int, str, int, str, str, str, str):
        ``(filepath, sample_rate, file_name, speaker_id, transcription, action, object, location)``
    """
    sample = self.data[n]

    # The CSV "path" column holds a relative path; keep only the bare file name
    # (no directories, no extension).
    file_name = sample[self.header.index("path")].split("/")[-1]
    file_name = file_name.split(".")[0]
    speaker_id, transcription, action, obj, location = sample[2:]

    # Path is relative to the dataset root; the waveform loader joins it with
    # self._path at load time.
    file_path = os.path.join("wavs", "speakers", speaker_id, f"{file_name}.wav")

    return file_path, SAMPLE_RATE, file_name, speaker_id, transcription, action, obj, location
def __len__(self) -> int:
    """Return the number of samples (CSV data rows) in the dataset."""
    return len(self.data)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, str, str, str, str]:
    """Load the n-th sample from the dataset.

    Args:
        n (int): The index of the sample to be loaded

    Returns:
        (Tensor, int, str, int, str, str, str, str):
        ``(waveform, sample_rate, file_name, speaker_id, transcription, action, object, location)``
    """
    # Reuse get_metadata so the two entry points cannot drift apart; only the
    # waveform-loading step differs.
    metadata = self.get_metadata(n)
    waveform = _load_waveform(self._path, metadata[0], metadata[1])
    return (waveform,) + metadata[1:]
# Module header for the QUESST14 dataset (post-commit version, reconstructed
# from the collapsed two-column diff; `import os` precedes this hunk).
from pathlib import Path
from typing import Optional, Tuple, Union

import torch
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _load_waveform, extract_archive

URL = "https://speech.fit.vutbr.cz/files/quesst14Database.tgz"
# QUESST14 recordings are all 8 kHz, so the rate is a module constant.
SAMPLE_RATE = 8000
_CHECKSUM = "4f869e06bc066bbe9c5dde31dbd3909a0870d70291110ebbb38878dcbc2fc5e4"
_LANGUAGES = [ _LANGUAGES = [
"albanian", "albanian",
...@@ -71,10 +71,20 @@ class QUESST14(Dataset): ...@@ -71,10 +71,20 @@ class QUESST14(Dataset):
elif subset == "eval": elif subset == "eval":
self.data = filter_audio_paths(self._path, language, "language_key_eval.lst") self.data = filter_audio_paths(self._path, language, "language_key_eval.lst")
def get_metadata(self, n: int) -> Tuple[str, int, str]:
    """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
    but otherwise returns the same fields as :py:func:`__getitem__`.

    Args:
        n (int): The index of the sample to be loaded

    Returns:
        (str, int, str):
        ``(filepath, sample_rate, file_name)``
    """
    audio_path = self.data[n]
    # Stored paths are absolute (pathlib.Path); report them relative to the
    # dataset root so callers can re-join with any root.
    relpath = os.path.relpath(audio_path, self._path)
    return relpath, SAMPLE_RATE, audio_path.with_suffix("").name
def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str]:
    """Load the n-th sample from the dataset.

    Args:
        n (int): The index of the sample to be loaded

    Returns:
        (Tensor, int, str): ``(waveform, sample_rate, file_name)``
    """
    # Delegate to get_metadata and only add the waveform load, keeping the two
    # access paths consistent.
    metadata = self.get_metadata(n)
    waveform = _load_waveform(self._path, metadata[0], metadata[1])
    return (waveform,) + metadata[1:]
def __len__(self) -> int:
    """Return the number of audio files in the selected subset."""
    return len(self.data)
......
# Module header for the VoxCeleb1 datasets (post-commit version, reconstructed
# from the collapsed two-column diff).
import os
from pathlib import Path
from typing import List, Tuple, Union

from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _load_waveform, extract_archive

# VoxCeleb1 audio is distributed at 16 kHz.
SAMPLE_RATE = 16000
_ARCHIVE_CONFIGS = { _ARCHIVE_CONFIGS = {
"dev": { "dev": {
"archive_name": "vox1_dev_wav.zip", "archive_name": "vox1_dev_wav.zip",
...@@ -111,6 +111,9 @@ class VoxCeleb1(Dataset): ...@@ -111,6 +111,9 @@ class VoxCeleb1(Dataset):
) )
_download_extract_wavs(root) _download_extract_wavs(root)
def get_metadata(self, n: int):
    """Abstract in the base class; implemented by the VoxCeleb1Identification
    and VoxCeleb1Verification subclasses."""
    raise NotImplementedError
def __getitem__(self, n: int):
    """Abstract in the base class; implemented by the VoxCeleb1Identification
    and VoxCeleb1Verification subclasses."""
    raise NotImplementedError
...@@ -145,6 +148,23 @@ class VoxCeleb1Identification(VoxCeleb1): ...@@ -145,6 +148,23 @@ class VoxCeleb1Identification(VoxCeleb1):
download_url_to_file(meta_url, meta_list_path) download_url_to_file(meta_url, meta_list_path)
self._flist = _get_flist(self._path, meta_list_path, subset) self._flist = _get_flist(self._path, meta_list_path, subset)
def get_metadata(self, n: int) -> Tuple[str, int, int, str]:
    """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
    but otherwise returns the same fields as :py:func:`__getitem__`.

    Args:
        n (int): The index of the sample

    Returns:
        (str, int, int, str):
        ``(filepath, sample_rate, speaker_id, file_id)``
    """
    file_path = self._flist[n]
    file_id = _get_file_id(file_path, self._ext_audio)
    # File ids look like "idXXXXX-..."; strip the "id" prefix of the first
    # dash-separated field to get the integer speaker id.
    speaker_id = file_id.split("-")[0]
    speaker_id = int(speaker_id[3:])
    return file_path, SAMPLE_RATE, speaker_id, file_id
def __getitem__(self, n: int) -> Tuple[Tensor, int, int, str]:
    """Load the n-th sample from the dataset.

    Args:
        n (int): The index of the sample

    Returns:
        (Tensor, int, int, str):
        ``(waveform, sample_rate, speaker_id, file_id)``
    """
    # Resolve metadata first, then load the waveform relative to the dataset
    # root — keeps __getitem__ and get_metadata in lockstep.
    metadata = self.get_metadata(n)
    waveform = _load_waveform(self._path, metadata[0], metadata[1])
    return (waveform,) + metadata[1:]
def __len__(self) -> int:
    """Return the number of entries in the file list for this subset."""
    return len(self._flist)
...@@ -190,6 +207,23 @@ class VoxCeleb1Verification(VoxCeleb1): ...@@ -190,6 +207,23 @@ class VoxCeleb1Verification(VoxCeleb1):
download_url_to_file(meta_url, meta_list_path) download_url_to_file(meta_url, meta_list_path)
self._flist = _get_paired_flist(self._path, meta_list_path) self._flist = _get_paired_flist(self._path, meta_list_path)
def get_metadata(self, n: int) -> Tuple[str, str, int, int, str, str]:
    """Get metadata for the n-th sample from the dataset. Returns filepaths instead of waveforms,
    but otherwise returns the same fields as :py:func:`__getitem__`.

    Args:
        n (int): The index of the sample

    Returns:
        (str, str, int, int, str, str):
        ``(filepath_spk1, filepath_spk2, sample_rate, label, file_id_spk1, file_id_spk2)``
    """
    # Each verification-trial entry pairs two utterances with a same/different
    # speaker label.
    label, file_path_spk1, file_path_spk2 = self._flist[n]
    label = int(label)
    file_id_spk1 = _get_file_id(file_path_spk1, self._ext_audio)
    file_id_spk2 = _get_file_id(file_path_spk2, self._ext_audio)
    return file_path_spk1, file_path_spk2, SAMPLE_RATE, label, file_id_spk1, file_id_spk2
def __getitem__(self, n: int) -> Tuple[Tensor, Tensor, int, int, str, str]:
    """Load the n-th sample from the dataset.

    Args:
        n (int): The index of the sample

    Returns:
        (Tensor, Tensor, int, int, str, str):
        ``(waveform_spk1, waveform_spk2, sample_rate, label, file_id_spk1, file_id_spk2)``
    """
    metadata = self.get_metadata(n)
    # Both utterances are loaded at the dataset-wide sample rate (metadata[2]),
    # which replaces the old per-file rate check.
    waveform_spk1 = _load_waveform(self._path, metadata[0], metadata[2])
    waveform_spk2 = _load_waveform(self._path, metadata[1], metadata[2])
    return (waveform_spk1, waveform_spk2) + metadata[2:]
def __len__(self) -> int:
    """Return the number of verification trials in the file list."""
    return len(self._flist)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment