Commit 08d3bb17 authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Add metadata function for LibriSpeech (#2653)

Summary:
Adding support for metadata mode, requested in https://github.com/pytorch/audio/issues/2539, by adding a public `get_metadata()` function in the dataset. This function can be used directly by users to fetch metadata for individual dataset indices, or users can subclass the dataset and override `__getitem__` with `get_metadata` to create a dataset class that directly handles metadata mode.

Pull Request resolved: https://github.com/pytorch/audio/pull/2653

Reviewed By: nateanl, mthrok

Differential Revision: D39105114

Pulled By: carolineechen

fbshipit-source-id: 6f26f1402a053dffcfcc5d859f87271ed5923348
parent 4a20c412
......@@ -2,43 +2,34 @@ import os
from pathlib import Path
from typing import List, Tuple, Union
import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.librispeech import load_librispeech_item
from torchaudio.datasets.librispeech import _get_librispeech_metadata
from torchaudio.datasets.utils import extract_archive
_ARCHIVE_NAME = "librispeech_finetuning"
_URL = "https://dl.fbaipublicfiles.com/librilight/data/librispeech_finetuning.tgz"
_CHECKSUM = "5d1efdc777b548194d7e09ba89126e2188026df9fd57aa57eb14408d2b2342af"
_SUBSET_MAP = {"10min": ["1h/0"], "1h": ["1h/*"], "10h": ["1h/*", "9h"]}
def _get_fileids_paths(path, subset, _ext_audio) -> List[Tuple[str, str]]:
def _get_fileids_paths(path, folders, _ext_audio) -> List[Tuple[str, str]]:
"""Get the file names and the corresponding file paths without `speaker_id`
and `chapter_id` directories.
The format of path is like:
{root}/{_ARCHIVE_NAME}/1h/[0-5]/[clean, other] or
{root}/{_ARCHIVE_NAME}/9h/[clean, other]
"""
if subset == "10min":
files_paths = [
(os.path.join(os.path.dirname(p), "..", ".."), str(p.stem))
for p in Path(path).glob("1h/0/*/*/*/*" + _ext_audio)
]
elif subset in ["1h", "10h"]:
files_paths = [
(os.path.join(os.path.dirname(p), "..", ".."), str(p.stem))
for p in Path(path).glob("1h/*/*/*/*/*" + _ext_audio)
]
if subset == "10h":
files_paths += [
(os.path.join(os.path.dirname(p), "..", ".."), str(p.stem))
for p in Path(path).glob("9h/*/*/*/*" + _ext_audio)
]
else:
raise ValueError(f"Unsupported subset value. Found {subset}.")
files_paths = sorted(files_paths, key=lambda x: x[0] + x[1])
path = Path(path)
files_paths = []
for folder in folders:
paths = [p.relative_to(path) for p in path.glob(f"{folder}/*/*/*/*{_ext_audio}")]
files_paths += [(str(p.parent.parent.parent), str(p.stem)) for p in paths] # get subset folder and file name
files_paths.sort(key=lambda x: x[0] + x[1])
return files_paths
......@@ -63,8 +54,9 @@ class LibriLightLimited(Dataset):
subset: str = "10min",
download: bool = False,
) -> None:
if subset not in ["10min", "1h", "10h"]:
raise ValueError("`subset` must be one of ['10min', '1h', '10h']")
if subset not in _SUBSET_MAP:
raise ValueError(f"`subset` must be one of {_SUBSET_MAP.keys()}. Found: {subset}")
folders = _SUBSET_MAP[subset]
root = os.fspath(root)
self._path = os.path.join(root, _ARCHIVE_NAME)
......@@ -75,7 +67,7 @@ class LibriLightLimited(Dataset):
if not os.path.isfile(archive):
download_url_to_file(_URL, archive, hash_prefix=_CHECKSUM)
extract_archive(archive)
self._fileids_paths = _get_fileids_paths(self._path, subset, self._ext_audio)
self._fileids_paths = _get_fileids_paths(self._path, folders, self._ext_audio)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
    """Load the n-th sample from the dataset.

    Args:
        n (int): The index of the sample to be loaded

    Returns:
        (Tensor, int, str, int, int, int):
        ``(waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)``
    """
    subset_folder, fileid = self._fileids_paths[n]
    metadata = _get_librispeech_metadata(fileid, self._path, subset_folder, self._ext_audio, self._ext_txt)
    # metadata[0] is the audio path relative to the dataset root.
    audio_path = os.path.join(self._path, metadata[0])
    waveform, _ = torchaudio.load(audio_path)
    return (waveform,) + metadata[1:]
def __len__(self) -> int:
    """Return the number of samples in the dataset."""
    return len(self._fileids_paths)
......@@ -10,6 +10,7 @@ from torchaudio.datasets.utils import extract_archive
URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriSpeech"
SAMPLE_RATE = 16000
_DATA_SUBSETS = [
"dev-clean",
"dev-other",
......@@ -30,7 +31,7 @@ _CHECKSUMS = {
}
def download_librispeech(root, url):
def _download_librispeech(root, url):
base_url = "http://www.openslr.org/resources/12/"
ext_archive = ".tar.gz"
......@@ -43,20 +44,18 @@ def download_librispeech(root, url):
extract_archive(archive)
def load_librispeech_item(
fileid: str, path: str, ext_audio: str, ext_txt: str
) -> Tuple[Tensor, int, str, int, int, int]:
def _get_librispeech_metadata(
fileid: str, root: str, folder: str, ext_audio: str, ext_txt: str
) -> Tuple[str, int, str, int, int, int]:
speaker_id, chapter_id, utterance_id = fileid.split("-")
# Load audio
# Get audio path and sample rate
fileid_audio = f"{speaker_id}-{chapter_id}-{utterance_id}"
file_audio = fileid_audio + ext_audio
file_audio = os.path.join(path, speaker_id, chapter_id, file_audio)
waveform, sample_rate = torchaudio.load(file_audio)
filepath = os.path.join(folder, speaker_id, chapter_id, f"{fileid_audio}{ext_audio}")
# Load text
file_text = f"{speaker_id}-{chapter_id}{ext_txt}"
file_text = os.path.join(path, speaker_id, chapter_id, file_text)
file_text = os.path.join(root, folder, speaker_id, chapter_id, file_text)
with open(file_text) as ft:
for line in ft:
fileid_text, transcript = line.strip().split(" ", 1)
......@@ -67,8 +66,8 @@ def load_librispeech_item(
raise FileNotFoundError(f"Translation not found for {fileid_audio}")
return (
waveform,
sample_rate,
filepath,
SAMPLE_RATE,
transcript,
int(speaker_id),
int(chapter_id),
......@@ -102,15 +101,17 @@ class LIBRISPEECH(Dataset):
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False,
) -> None:
self._url = url
if url not in _DATA_SUBSETS:
raise ValueError(f"Invalid url '{url}' given; please provide one of {_DATA_SUBSETS}.")
root = os.fspath(root)
self._archive = os.path.join(root, folder_in_archive)
self._path = os.path.join(root, folder_in_archive, url)
if not os.path.isdir(self._path):
if download:
download_librispeech(root, url)
_download_librispeech(root, url)
else:
raise RuntimeError(
f"Dataset not found at {self._path}. Please set `download=True` to download the dataset."
......@@ -118,6 +119,27 @@ class LIBRISPEECH(Dataset):
self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio))
def get_metadata(self, n: int) -> Tuple[str, int, str, int, int, int]:
    """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
    but otherwise returns the same fields as :py:func:`__getitem__`.

    Args:
        n (int): The index of the sample to be loaded

    Returns:
        (str, int, str, int, int, int):
        ``(filepath, sample_rate, transcript, speaker_id, chapter_id, utterance_id)``
    """
    # Fix: the return annotation previously claimed Tuple[Tensor, ...], but the
    # first element is the audio file path (a str), as the docstring and the
    # helper's return value show.
    fileid = self._walker[n]
    return _get_librispeech_metadata(fileid, self._archive, self._url, self._ext_audio, self._ext_txt)
def _load_waveform(self, path: str):
    """Load the audio file at ``path`` (relative to the archive root).

    Args:
        path (str): audio file path relative to ``self._archive``.

    Returns:
        Tensor: the loaded waveform.

    Raises:
        ValueError: if the file's sample rate is not 16 kHz.
    """
    full_path = os.path.join(self._archive, path)
    waveform, sample_rate = torchaudio.load(full_path)
    if sample_rate != SAMPLE_RATE:
        raise ValueError(f"sample rate should be 16000 (16kHz), but got {sample_rate}.")
    return waveform
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
    """Load the n-th sample from the dataset.

    Args:
        n (int): The index of the sample to be loaded

    Returns:
        (Tensor, int, str, int, int, int):
        ``(waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)``
    """
    # Delegate to get_metadata, then swap the leading filepath for the waveform.
    filepath, *rest = self.get_metadata(n)
    waveform = self._load_waveform(filepath)
    return (waveform, *rest)
def __len__(self) -> int:
    """Return the number of samples in the dataset."""
    return len(self._walker)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment