Commit af9cab3b authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Add LibriLightLimited dataset (#2302)

Summary:
The `LibriLightLimited` dataset is created for fine-tuning SSL models, such as Wav2Vec2 and HuBERT. It is a supervised subset of the [Libri-Light](https://github.com/facebookresearch/libri-light) dataset. To distinguish between the unsupervised subset and the supervised one, it is clearer to put it in a separate dataset class for fine-tuning purposes.
It contains "10 min", "1 hour", "10 hour" splits.

Pull Request resolved: https://github.com/pytorch/audio/pull/2302

Reviewed By: mthrok

Differential Revision: D36388188

Pulled By: nateanl

fbshipit-source-id: ba49f1c9996be17db5db41127d8ca96224c94249
parent 48a0c17a
......@@ -66,6 +66,14 @@ LIBRISPEECH
:special-members: __getitem__
LibriLightLimited
~~~~~~~~~~~~~~~~~
.. autoclass:: LibriLightLimited
:members:
:special-members: __getitem__
LIBRITTS
~~~~~~~~
......
import os
from torchaudio.datasets import librilight_limited
from torchaudio_unittest.common_utils import (
get_whitenoise,
save_wav,
TempDirMixin,
TorchaudioTestCase,
)
# Used to generate a unique transcript for each dummy audio file
_NUMBERS = ["ZERO", "ONE", "TWO", "THREE", "FOUR", "FIVE", "SIX", "SEVEN", "EIGHT", "NINE"]
def _save_sample(file_path, speaker_id, chapter_id, utterance_id, sample_rate, seed):
    """Write one dummy audio file and return the tuple the dataset should yield for it."""
    flac_name = f"{speaker_id}-{chapter_id}-{utterance_id:04d}.flac"
    flac_path = os.path.join(file_path, flac_name)
    # A very short white-noise clip keeps the fixture tiny while remaining loadable.
    waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1, dtype="float32", seed=seed)
    # Derive the transcript from the ids so every file gets a unique sentence.
    transcript = " ".join(_NUMBERS[i] for i in (speaker_id, chapter_id, utterance_id))
    save_wav(flac_path, waveform, sample_rate)
    return (waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)
def get_mock_dataset(dataset_dir: str):
    """Create mocked dataset for a sub directory.

    Args:
        dataset_dir (str): the path of the sub directory.
            The structure is: audio_type/speaker_id/chapter_id/filename.flac
    """
    samples = []
    rate = 16000  # 16kHz
    seed = 0
    for audio_type in ("clean", "other"):
        for spk in range(5):
            for chp in range(3):
                chapter_dir = os.path.join(dataset_dir, audio_type, str(spk), str(chp))
                os.makedirs(chapter_dir, exist_ok=True)
                # One transcript file per chapter; one line per utterance.
                trans_lines = []
                for utt in range(3):
                    sample = _save_sample(chapter_dir, spk, chp, utt, rate, seed)
                    trans_lines.append(f"{sample[3]}-{sample[4]}-{sample[5]:04d} {sample[2]}")
                    samples.append(sample)
                    seed += 1
                trans_path = os.path.join(chapter_dir, f"{spk}-{chp}.trans.txt")
                with open(trans_path, "w") as f:
                    f.write("\n".join(trans_lines))
    return samples
def get_mock_datasets(root_dir):
    """
    root_dir: directory to the mocked dataset

    Builds the on-disk layout librispeech_finetuning/{1h/[0-5], 9h} and
    returns the cumulative expected samples for the 10min, 1h and 10h subsets.
    """
    base = os.path.join(root_dir, "librispeech_finetuning")

    def _populate(*parts):
        # Create one split directory and fill it with mocked audio files.
        target = os.path.join(base, *parts)
        os.makedirs(target, exist_ok=True)
        return get_mock_dataset(target)

    # 10min == split "1h/0"; 1h == all of "1h/[0-5]"; 10h adds "9h" on top.
    data_10min = _populate("1h", "0")
    data_1h = list(data_10min)
    for split_index in range(1, 6):
        data_1h += _populate("1h", str(split_index))
    data_10h = list(data_1h)
    data_10h += _populate("9h")
    return data_10min, data_1h, data_10h
class TestLibriLightLimited(TempDirMixin, TorchaudioTestCase):
    """End-to-end checks that LibriLightLimited yields exactly the mocked samples."""

    backend = "default"
    root_dir = None
    samples_10min = []
    samples_1h = []
    samples_10h = []

    @classmethod
    def setUpClass(cls):
        cls.root_dir = cls.get_base_temp_dir()
        (cls.samples_10min, cls.samples_1h, cls.samples_10h) = get_mock_datasets(cls.root_dir)

    def _test_librilightlimited(self, dataset, samples):
        seen = 0
        for idx, (waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id) in enumerate(dataset):
            expected = samples[idx]
            # Audio data is compared with tolerance; metadata must match exactly.
            self.assertEqual(waveform, expected[0], atol=5e-5, rtol=1e-8)
            assert (sample_rate, transcript, speaker_id, chapter_id, utterance_id) == expected[1:]
            seen += 1
        assert seen == len(samples)

    def test_librilightlimited_10min(self):
        dataset = librilight_limited.LibriLightLimited(self.root_dir, subset="10min")
        self._test_librilightlimited(dataset, self.samples_10min)

    def test_librilightlimited_1h(self):
        dataset = librilight_limited.LibriLightLimited(self.root_dir, subset="1h")
        self._test_librilightlimited(dataset, self.samples_1h)

    def test_librilightlimited_10h(self):
        dataset = librilight_limited.LibriLightLimited(self.root_dir, subset="10h")
        self._test_librilightlimited(dataset, self.samples_10h)
......@@ -3,6 +3,7 @@ from .cmudict import CMUDict
from .commonvoice import COMMONVOICE
from .dr_vctk import DR_VCTK
from .gtzan import GTZAN
from .librilight_limited import LibriLightLimited
from .librimix import LibriMix
from .librispeech import LIBRISPEECH
from .libritts import LIBRITTS
......@@ -17,6 +18,7 @@ from .yesno import YESNO
__all__ = [
"COMMONVOICE",
"LIBRISPEECH",
"LibriLightLimited",
"SPEECHCOMMANDS",
"VCTK_092",
"DR_VCTK",
......
import os
from pathlib import Path
from typing import List, Tuple, Union
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.librispeech import load_librispeech_item
from torchaudio.datasets.utils import extract_archive
_ARCHIVE_NAME = "librispeech_finetuning"
_URL = "https://dl.fbaipublicfiles.com/librilight/data/librispeech_finetuning.tgz"
_CHECKSUM = "5d1efdc777b548194d7e09ba89126e2188026df9fd57aa57eb14408d2b2342af"
def _get_fileids_paths(path, subset, _ext_audio) -> List[Tuple[str, str]]:
"""Get the file names and the corresponding file paths without `speaker_id`
and `chapter_id` directories.
The format of path is like:
{root}/{_ARCHIVE_NAME}/1h/[0-5]/[clean, other] or
{root}/{_ARCHIVE_NAME}/9h/[clean, other]
"""
if subset == "10min":
files_paths = [
(os.path.join(os.path.dirname(p), "..", ".."), str(p.stem))
for p in Path(path).glob("1h/0/*/*/*/*" + _ext_audio)
]
elif subset in ["1h", "10h"]:
files_paths = [
(os.path.join(os.path.dirname(p), "..", ".."), str(p.stem))
for p in Path(path).glob("1h/*/*/*/*/*" + _ext_audio)
]
if subset == "10h":
files_paths += [
(os.path.join(os.path.dirname(p), "..", ".."), str(p.stem))
for p in Path(path).glob("9h/*/*/*/*" + _ext_audio)
]
else:
raise ValueError(f"Unsupported subset value. Found {subset}.")
files_paths = sorted(files_paths, key=lambda x: x[0] + x[1])
return files_paths
class LibriLightLimited(Dataset):
    """Create a Dataset for LibriLightLimited, which is the supervised subset of
    the LibriLight dataset, intended for fine-tuning.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        subset (str, optional): The subset to use. Options: [``10min``, ``1h``, ``10h``]
            (Default: ``10min``).
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
    """

    # File-name extensions for transcript and audio files, respectively.
    _ext_txt = ".trans.txt"
    _ext_audio = ".flac"

    def __init__(
        self,
        root: Union[str, Path],
        subset: str = "10min",
        download: bool = False,
    ) -> None:
        assert subset in ["10min", "1h", "10h"], "`subset` must be one of ['10min', '1h', '10h']"
        root = os.fspath(root)
        self._path = os.path.join(root, _ARCHIVE_NAME)
        archive = os.path.join(root, f"{_ARCHIVE_NAME}.tgz")
        if not os.path.isdir(self._path):
            if not download:
                raise RuntimeError("Dataset not found. Please use `download=True` to download")
            if not os.path.isfile(archive):
                # hash_prefix validates the download against the known SHA256 checksum.
                download_url_to_file(_URL, archive, hash_prefix=_CHECKSUM)
            extract_archive(archive)
        # List of (directory, file-id) pairs; the directory component excludes
        # the speaker_id/chapter_id levels, which are reconstructed at load time.
        self._fileids_paths = _get_fileids_paths(self._path, subset, self._ext_audio)

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            (Tensor, int, str, int, int, int):
            ``(waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)``
        """
        file_path, fileid = self._fileids_paths[n]
        return load_librispeech_item(fileid, file_path, self._ext_audio, self._ext_txt)

    def __len__(self) -> int:
        """Return the number of samples in the selected subset."""
        return len(self._fileids_paths)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment