Commit aebcf6af authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Add QUESST14 dataset (#2290)

Summary:
implementation adapted from [s3prl](https://github.com/s3prl/s3prl/blob/master/s3prl/downstream/quesst14_dtw/dataset.py)

Modifying the s3prl downstream expert as in [this commit](https://github.com/carolineechen/s3prl/commit/adc91a53d581a604f495f3795a865d84aa17f1a5) to use this dataset implementation produces the same results as the original s3prl pipeline.

Pull Request resolved: https://github.com/pytorch/audio/pull/2290

Reviewed By: nateanl

Differential Revision: D35692551

Pulled By: carolineechen

fbshipit-source-id: 035ad161d4cbbd2072411cfdf89984b73a89868c
parent 86100e38
...@@ -112,3 +112,10 @@ YESNO ...@@ -112,3 +112,10 @@ YESNO
.. autoclass:: YESNO .. autoclass:: YESNO
:members: :members:
:special-members: __getitem__ :special-members: __getitem__
QUESST14
~~~~~~~~
.. autoclass:: QUESST14
:members:
:special-members: __getitem__
import os
from collections import defaultdict
from pathlib import Path
from parameterized import parameterized
from torchaudio.datasets import quesst14
from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
get_whitenoise,
save_wav,
)
def _get_filename(folder, index):
if folder == "Audio":
return f"quesst14_{index:05d}.wav"
elif folder == "dev_queries":
return f"quesst14_dev_{index:04d}.wav"
elif folder == "eval_queries":
return f"quesst14_eval_{index:04d}.wav"
return
def _get_key(folder):
folder_key_mapping = {
"Audio": "utterances",
"dev_queries": "dev",
"eval_queries": "eval",
}
return folder_key_mapping[folder]
def _save_sample(dataset_dir, folder, language, index, sample_rate, seed):
    """Create one white-noise wav under ``folder`` and register it in the
    corresponding language-key list file.

    Returns:
        (Tensor, str): the waveform and the extension-less file stem, the
        same pair the QUESST14 dataset yields for this file.
    """
    # create and save audio samples to corresponding files
    path = os.path.join(dataset_dir, folder)
    os.makedirs(path, exist_ok=True)
    filename = _get_filename(folder, index)
    file_path = os.path.join(path, filename)
    data = get_whitenoise(
        sample_rate=sample_rate,
        duration=0.01,
        n_channels=1,
        seed=seed,
    )
    save_wav(file_path, data, sample_rate)
    sample = (data, Path(file_path).with_suffix("").name)

    # add audio files and language data to language key files
    scoring_path = os.path.join(dataset_dir, "scoring")
    os.makedirs(scoring_path, exist_ok=True)
    # Fix: the key-file entry must reference the actual saved file name
    # (the original had a garbled placeholder here), otherwise QUESST14's
    # language filtering resolves to paths that do not exist on disk.
    wav_file = f"quesst14Database/{folder}/{filename}"
    line = f"{wav_file} {language}"
    key = _get_key(folder)
    language_key_file = f"language_key_{key}.lst"
    language_key_file = os.path.join(scoring_path, language_key_file)
    with open(language_key_file, "a") as f:
        f.write(line + "\n")

    return sample
def _get_mocked_samples(dataset_dir, folder, sample_rate, seed):
    """Write two mock samples per language into ``folder``.

    Returns:
        (dict, list): samples grouped per language, and the same samples
        as one flat list in creation order.
    """
    per_language = 2
    by_language = defaultdict(list)
    flat = []
    index = 0
    for language in quesst14._LANGUAGES:
        for _ in range(per_language):
            sample = _save_sample(dataset_dir, folder, language, index, sample_rate, seed)
            by_language[language].append(sample)
            flat.append(sample)
            index += 1
    return by_language, flat
def get_mock_dataset(dataset_dir):
    """Populate ``dataset_dir`` with mocked QUESST14 audio and key files.

    dataset_dir: directory to the mocked dataset

    Returns the per-language maps followed by the flat sample lists for
    the utterances, dev queries, and eval queries, in that order.
    """
    os.makedirs(dataset_dir, exist_ok=True)
    sample_rate = 8000
    # Distinct seeds so each subset gets distinguishable waveforms.
    folder_seeds = [("Audio", 0), ("dev_queries", 1), ("eval_queries", 2)]
    grouped = []
    flattened = []
    for folder, seed in folder_seeds:
        samples_map, samples_all = _get_mocked_samples(dataset_dir, folder, sample_rate, seed)
        grouped.append(samples_map)
        flattened.append(samples_all)
    return tuple(grouped) + tuple(flattened)
class TestQuesst14(TempDirMixin, TorchaudioTestCase):
    """Tests for the QUESST14 dataset against a mocked on-disk corpus."""

    root_dir = None
    backend = "default"

    utterances = {}
    dev_samples = {}
    eval_samples = {}
    utterances_all = []
    dev_samples_all = []
    eval_samples_all = []

    @classmethod
    def setUpClass(cls):
        cls.root_dir = cls.get_base_temp_dir()
        dataset_dir = os.path.join(cls.root_dir, "quesst14Database")
        (
            cls.utterances,
            cls.dev_samples,
            cls.eval_samples,
            cls.utterances_all,
            cls.dev_samples_all,
            cls.eval_samples_all,
        ) = get_mock_dataset(dataset_dir)

    def _testQuesst14(self, dataset, data_samples):
        # Walk the dataset, check every yielded (waveform, name) pair
        # against the mocked fixtures, then verify nothing was skipped.
        seen = 0
        for idx, (waveform, name) in enumerate(dataset):
            expected_waveform, expected_name = data_samples[idx]
            self.assertEqual(waveform, expected_waveform)
            assert name == expected_name
            seen += 1
        assert seen == len(data_samples)

    def testQuesst14SubsetDocs(self):
        dataset = quesst14.QUESST14(self.root_dir, language=None, subset="docs")
        self._testQuesst14(dataset, self.utterances_all)

    def testQuesst14SubsetDev(self):
        dataset = quesst14.QUESST14(self.root_dir, language=None, subset="dev")
        self._testQuesst14(dataset, self.dev_samples_all)

    def testQuesst14SubsetEval(self):
        dataset = quesst14.QUESST14(self.root_dir, language=None, subset="eval")
        self._testQuesst14(dataset, self.eval_samples_all)

    @parameterized.expand(quesst14._LANGUAGES)
    def testQuesst14DocsSingleLanguage(self, language):
        dataset = quesst14.QUESST14(self.root_dir, language=language, subset="docs")
        self._testQuesst14(dataset, self.utterances[language])

    @parameterized.expand(quesst14._LANGUAGES)
    def testQuesst14DevSingleLanguage(self, language):
        dataset = quesst14.QUESST14(self.root_dir, language=language, subset="dev")
        self._testQuesst14(dataset, self.dev_samples[language])

    @parameterized.expand(quesst14._LANGUAGES)
    def testQuesst14EvalSingleLanguage(self, language):
        dataset = quesst14.QUESST14(self.root_dir, language=language, subset="eval")
        self._testQuesst14(dataset, self.eval_samples[language])
...@@ -7,6 +7,7 @@ from .librimix import LibriMix ...@@ -7,6 +7,7 @@ from .librimix import LibriMix
from .librispeech import LIBRISPEECH from .librispeech import LIBRISPEECH
from .libritts import LIBRITTS from .libritts import LIBRITTS
from .ljspeech import LJSPEECH from .ljspeech import LJSPEECH
from .quesst14 import QUESST14
from .speechcommands import SPEECHCOMMANDS from .speechcommands import SPEECHCOMMANDS
from .tedlium import TEDLIUM from .tedlium import TEDLIUM
from .vctk import VCTK_092 from .vctk import VCTK_092
...@@ -27,4 +28,5 @@ __all__ = [ ...@@ -27,4 +28,5 @@ __all__ = [
"LibriMix", "LibriMix",
"LIBRITTS", "LIBRITTS",
"TEDLIUM", "TEDLIUM",
"QUESST14",
] ]
import os
import re
from pathlib import Path
from typing import Tuple, Union, Optional
import torch
import torchaudio
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive
# Location of the official QUESST14 release archive.
URL = "https://speech.fit.vutbr.cz/files/quesst14Database.tgz"
# Expected digest of the archive, passed to `download_url_to_file` as `hash_prefix`.
_CHECKSUM = "4f869e06bc066bbe9c5dde31dbd3909a0870d70291110ebbb38878dcbc2fc5e4"
# Languages appearing in the dataset's language-key list files.
_LANGUAGES = [
    "albanian",
    "basque",
    "czech",
    "nnenglish",
    "romanian",
    "slovak",
]
class QUESST14(Dataset):
    """Create QUESST14 Dataset

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found
        language (str or None, optional): Language to get dataset for.
            Options: [None, ``albanian``, ``basque``, ``czech``, ``nnenglish``, ``romanian``, ``slovak``].
            If ``None``, samples of all languages are kept.
            (default: ``"nnenglish"``)
        subset (str): subset of the dataset to use. Options: ["docs", "dev", "eval"].
        download (bool, optional): Whether to download the dataset if it is not found at root path.
            (default: ``False``)
    """

    # Maps each subset to the scoring list file enumerating its audio files.
    _SUBSET_TO_LST = {
        "docs": "language_key_utterances.lst",
        "dev": "language_key_dev.lst",
        "eval": "language_key_eval.lst",
    }

    def __init__(
        self,
        root: Union[str, Path],
        language: Optional[str] = "nnenglish",
        subset: Optional[str] = None,
        download: bool = False,
    ) -> None:
        # Validate with real exceptions rather than `assert`: asserts are
        # stripped under `python -O`, and the original check let `subset=None`
        # through, leaving `self.data` unset and deferring the failure to an
        # opaque AttributeError on first use. Fail fast here instead.
        if subset not in self._SUBSET_TO_LST:
            raise ValueError("`subset` must be one of ['docs', 'dev', 'eval']")
        if language is not None and language not in _LANGUAGES:
            raise ValueError(f"`language` must be None or one of {str(_LANGUAGES)}")

        # Get string representation of 'root'
        root = os.fspath(root)

        basename = os.path.basename(URL)
        archive = os.path.join(root, basename)

        # Drop the archive extension to get the extracted directory name.
        basename = basename.rsplit(".", 2)[0]
        self._path = os.path.join(root, basename)

        if not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                if not download:
                    raise RuntimeError("Dataset not found. Please use `download=True` to download")
                download_url_to_file(URL, archive, hash_prefix=_CHECKSUM)
            extract_archive(archive, root)

        # List of audio file paths for the requested subset/language.
        self.data = filter_audio_paths(self._path, language, self._SUBSET_TO_LST[subset])

    def _load_sample(self, n: int) -> Tuple[torch.Tensor, str]:
        # Load the n-th audio file and pair it with its extension-less stem.
        audio_path = self.data[n]
        wav, _ = torchaudio.load(audio_path)
        return wav, audio_path.with_suffix("").name

    def __getitem__(self, n: int) -> Tuple[torch.Tensor, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            (Tensor, str): ``(waveform, file name without extension)``
        """
        return self._load_sample(n)

    def __len__(self) -> int:
        return len(self.data)
def filter_audio_paths(
    path: str,
    language: Optional[str],
    lst_name: str,
) -> List[Path]:
    """Extract audio paths for the given language.

    Args:
        path (str): root of the extracted ``quesst14Database`` directory
        language (str or None): keep only entries of this language;
            ``None`` keeps every entry (the annotation was previously
            ``str`` although ``None`` is explicitly supported below)
        lst_name (str): name of the language-key list file under ``scoring/``

    Returns:
        List[Path]: paths of the matching audio files, rooted at ``path``
    """
    audio_paths = []

    path = Path(path)
    with open(path / "scoring" / lst_name) as f:
        for line in f:
            audio_path, lang = line.strip().split()
            if language is not None and lang != language:
                continue
            # Drop the leading "quesst14Database/" component so the entry
            # can be re-rooted under the local dataset directory.
            audio_path = re.sub(r"^.*?\/", "", audio_path)
            audio_paths.append(path / audio_path)

    return audio_paths
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment