Commit 84187909 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot

Add Snips Dataset (#2738)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2738

Reviewed By: carolineechen

Differential Revision: D40238099

Pulled By: nateanl

fbshipit-source-id: c5cc94c2a348a6ef34c04b8dd26114ecb874d73e
parent c5b8e585
@@ -40,6 +40,7 @@ For example:
LJSPEECH
MUSDB_HQ
QUESST14
Snips
SPEECHCOMMANDS
TEDLIUM
VCTK_092
@@ -433,3 +433,9 @@ abstract = {End-to-end spoken language translation (SLT) has recently gained pop
journal = {Language Resources and Evaluation},
doi = {10.1007/s10579-008-9076-6}
}
@article{coucke2018snips,
title={Snips voice platform: an embedded spoken language understanding system for private-by-design voice interfaces},
author={Coucke, Alice and Saade, Alaa and Ball, Adrien and Bluche, Th{\'e}odore and Caulier, Alexandre and Leroy, David and Doumouro, Cl{\'e}ment and Gisselbrecht, Thibault and Caltagirone, Francesco and Lavril, Thibaut and others},
journal={arXiv preprint arXiv:1805.10190},
year={2018}
}
import os
from torchaudio.datasets import snips
from torchaudio_unittest.common_utils import get_whitenoise, save_wav, TempDirMixin, TorchaudioTestCase
_SAMPLE_RATE = 16000
_SPEAKERS = [
"Aditi",
"Amy",
"Brian",
"Emma",
"Geraint",
"Ivy",
"Joanna",
"Joey",
"Justin",
"Kendra",
"Kimberly",
"Matthew",
"Nicole",
"Raveena",
"Russell",
"Salli",
]
def _save_wav(filepath: str, seed: int):
wav = get_whitenoise(
sample_rate=_SAMPLE_RATE,
duration=0.01,
n_channels=1,
seed=seed,
)
save_wav(filepath, wav, _SAMPLE_RATE)
return wav
def _save_label(label_path: str, wav_stem: str, label: str):
with open(label_path, "a") as f:
f.write(f"{wav_stem} {label}\n")
def _get_mocked_samples(dataset_dir: str, subset: str, seed: int):
samples = []
subset_dir = os.path.join(dataset_dir, subset)
label_path = os.path.join(dataset_dir, "all.iob.snips.txt")
os.makedirs(subset_dir, exist_ok=True)
num_utterance_per_split = 10
for spk in _SPEAKERS:
for i in range(num_utterance_per_split):
wav_stem = f"{spk}-snips-{subset}-{i}"
wav_path = os.path.join(subset_dir, f"{wav_stem}.wav")
waveform = _save_wav(wav_path, seed)
transcript, iob, intent = f"{spk}XXX", f"{spk}YYY", f"{spk}ZZZ"
label = "BOS " + transcript + " EOS\tO " + iob + " " + intent
_save_label(label_path, wav_stem, label)
samples.append((waveform, _SAMPLE_RATE, transcript, iob, intent))
return samples
def get_mock_datasets(dataset_dir):
"""
    dataset_dir (str): The directory where the mocked dataset is created.
"""
os.makedirs(dataset_dir, exist_ok=True)
train_seed = 0
valid_seed = 1
test_seed = 2
mocked_train_samples = _get_mocked_samples(dataset_dir, "train", train_seed)
mocked_valid_samples = _get_mocked_samples(dataset_dir, "valid", valid_seed)
mocked_test_samples = _get_mocked_samples(dataset_dir, "test", test_seed)
return (
mocked_train_samples,
mocked_valid_samples,
mocked_test_samples,
)
class TestSnips(TempDirMixin, TorchaudioTestCase):
root_dir = None
backend = "default"
train_samples = {}
valid_samples = {}
test_samples = {}
@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
dataset_dir = os.path.join(cls.root_dir, "SNIPS")
(
cls.train_samples,
cls.valid_samples,
cls.test_samples,
) = get_mock_datasets(dataset_dir)
def _testSnips(self, dataset, data_samples):
num_samples = 0
for i, (data, sample_rate, transcript, iob, intent) in enumerate(dataset):
self.assertEqual(data, data_samples[i][0])
assert sample_rate == data_samples[i][1]
assert transcript == data_samples[i][2]
assert iob == data_samples[i][3]
assert intent == data_samples[i][4]
num_samples += 1
assert num_samples == len(data_samples)
def testSnipsTrain(self):
dataset = snips.Snips(self.root_dir, subset="train", audio_format="wav")
self._testSnips(dataset, self.train_samples)
def testSnipsValid(self):
dataset = snips.Snips(self.root_dir, subset="valid", audio_format="wav")
self._testSnips(dataset, self.valid_samples)
def testSnipsTest(self):
dataset = snips.Snips(self.root_dir, subset="test", audio_format="wav")
self._testSnips(dataset, self.test_samples)
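For orientation, the helpers above write out the same on-disk layout that the Snips loader below expects under the dataset root. A sketch of that layout as the tests create it (entries are the mocked file names, not real corpus data):

SNIPS/
    all.iob.snips.txt    (one label line per utterance, shared across subsets)
    train/
        Aditi-snips-train-0.wav
        ...
    valid/
        Aditi-snips-valid-0.wav
        ...
    test/
        Aditi-snips-test-0.wav
        ...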
@@ -12,6 +12,7 @@ from .libritts import LIBRITTS
from .ljspeech import LJSPEECH
from .musdb_hq import MUSDB_HQ
from .quesst14 import QUESST14
from .snips import Snips
from .speechcommands import SPEECHCOMMANDS
from .tedlium import TEDLIUM
from .vctk import VCTK_092
@@ -40,4 +41,5 @@ __all__ = [
"VoxCeleb1Identification",
"VoxCeleb1Verification",
"IEMOCAP",
"Snips",
]
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union
import torch
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _load_waveform
_SAMPLE_RATE = 16000
_SPEAKERS = [
"Aditi",
"Amy",
"Brian",
"Emma",
"Geraint",
"Ivy",
"Joanna",
"Joey",
"Justin",
"Kendra",
"Kimberly",
"Matthew",
"Nicole",
"Raveena",
"Russell",
"Salli",
]
def _load_labels(file: Path, subset: str):
"""Load transcirpt, iob, and intent labels for all utterances.
Args:
file (Path): The path to the label file.
subset (str): Subset of the dataset to use. Options: [``"train"``, ``"valid"``, ``"test"``].
Returns:
        Dictionary of labels, where the key is the stem of the audio filename,
        and the value is a Tuple of transcript, inside–outside–beginning (IOB) label, and intent label.
"""
labels = {}
with open(file, "r") as f:
for line in f:
line = line.strip().split(" ")
index = line[0]
trans, iob_intent = " ".join(line[1:]).split("\t")
trans = " ".join(trans.split(" ")[1:-1])
iob = " ".join(iob_intent.split(" ")[1:-1])
intent = iob_intent.split(" ")[-1]
if subset in index:
labels[index] = (trans, iob, intent)
return labels
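# For reference, a label line such as the following (values hypothetical, with the
# format inferred from the parsing above; \t denotes a literal tab character):
#     Aditi-snips-train-0 BOS turn on the lights EOS\tO O O O B-device TurnLightOn
# is parsed into:
#     labels["Aditi-snips-train-0"] == ("turn on the lights", "O O O B-device", "TurnLightOn")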
class Snips(Dataset):
"""*Snips* :cite:`coucke2018snips` dataset.
Args:
root (str or Path): Root directory where the dataset's top level directory is found.
subset (str): Subset of the dataset to use. Options: [``"train"``, ``"valid"``, ``"test"``].
speakers (List[str] or None, optional): The speaker list to include in the dataset. If ``None``,
include all speakers in the subset. (Default: ``None``)
        audio_format (str, optional): The extension of the audio files. Options: [``"mp3"``, ``"wav"``].
(Default: ``"mp3"``)
"""
_trans_file = "all.iob.snips.txt"
def __init__(
self,
root: Union[str, Path],
subset: str,
speakers: Optional[List[str]] = None,
audio_format: str = "mp3",
) -> None:
if subset not in ["train", "valid", "test"]:
raise ValueError('`subset` must be one of ["train", "valid", "test"].')
if audio_format not in ["mp3", "wav"]:
            raise ValueError('`audio_format` must be one of ["mp3", "wav"].')
root = Path(root)
self._path = root / "SNIPS"
self.audio_path = self._path / subset
if speakers is None:
speakers = _SPEAKERS
if not os.path.isdir(self._path):
raise RuntimeError("Dataset not found.")
self.audio_paths = self.audio_path.glob(f"*.{audio_format}")
self.data = []
for audio_path in sorted(self.audio_paths):
audio_name = str(audio_path.name)
speaker = audio_name.split("-")[0]
if speaker in speakers:
self.data.append(audio_path)
transcript_path = self._path / self._trans_file
self.labels = _load_labels(transcript_path, subset)
def get_metadata(self, n: int) -> Tuple[str, int, str, str, str]:
"""Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
but otherwise returns the same fields as :py:func:`__getitem__`.
Args:
n (int): The index of the sample to be loaded.
Returns:
Tuple of the following items:
str:
Path to audio
int:
Sample rate
str:
Transcription of audio
str:
Inside–outside–beginning (IOB) label of transcription
str:
                Intent label of the audio.
"""
audio_path = self.data[n]
relpath = os.path.relpath(audio_path, self._path)
file_name = audio_path.with_suffix("").name
transcript, iob, intent = self.labels[file_name]
return relpath, _SAMPLE_RATE, transcript, iob, intent
def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
Tuple of the following items:
Tensor:
Waveform
int:
Sample rate
str:
Transcription of audio
str:
Inside–outside–beginning (IOB) label of transcription
str:
                Intent label of the audio.
"""
metadata = self.get_metadata(n)
waveform = _load_waveform(self._path, metadata[0], metadata[1])
return (waveform,) + metadata[1:]
def __len__(self) -> int:
return len(self.data)
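A minimal usage sketch for the class above, assuming the corpus has already been obtained and extracted so that the SNIPS directory sits directly under root (the path below is illustrative):

from torch.utils.data import DataLoader
from torchaudio.datasets import Snips

dataset = Snips("path/to/root", subset="train", audio_format="wav")
waveform, sample_rate, transcript, iob, intent = dataset[0]

# batch_size=None disables automatic batching, so the variable-length
# waveforms are yielded one sample at a time without collation.
loader = DataLoader(dataset, batch_size=None, num_workers=2)
for waveform, sample_rate, transcript, iob, intent in loader:
    pass  # e.g. extract features or feed a spoken-language-understanding model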