Commit 08d3bb17 authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Add metadata function for LibriSpeech (#2653)

Summary:
Adding support for metadata mode, requested in https://github.com/pytorch/audio/issues/2539, by adding a public `get_metadata()` function in the dataset. This function can be used directly by users to fetch metadata for individual dataset indices, or users can subclass the dataset and override `__getitem__` with `get_metadata` to create a dataset class that directly handles metadata mode.

Pull Request resolved: https://github.com/pytorch/audio/pull/2653

Reviewed By: nateanl, mthrok

Differential Revision: D39105114

Pulled By: carolineechen

fbshipit-source-id: 6f26f1402a053dffcfcc5d859f87271ed5923348
parent 4a20c412
...@@ -2,43 +2,34 @@ import os ...@@ -2,43 +2,34 @@ import os
from pathlib import Path from pathlib import Path
from typing import List, Tuple, Union from typing import List, Tuple, Union
import torchaudio
from torch import Tensor from torch import Tensor
from torch.hub import download_url_to_file from torch.hub import download_url_to_file
from torch.utils.data import Dataset from torch.utils.data import Dataset
from torchaudio.datasets.librispeech import load_librispeech_item from torchaudio.datasets.librispeech import _get_librispeech_metadata
from torchaudio.datasets.utils import extract_archive from torchaudio.datasets.utils import extract_archive
_ARCHIVE_NAME = "librispeech_finetuning" _ARCHIVE_NAME = "librispeech_finetuning"
_URL = "https://dl.fbaipublicfiles.com/librilight/data/librispeech_finetuning.tgz" _URL = "https://dl.fbaipublicfiles.com/librilight/data/librispeech_finetuning.tgz"
_CHECKSUM = "5d1efdc777b548194d7e09ba89126e2188026df9fd57aa57eb14408d2b2342af" _CHECKSUM = "5d1efdc777b548194d7e09ba89126e2188026df9fd57aa57eb14408d2b2342af"
_SUBSET_MAP = {"10min": ["1h/0"], "1h": ["1h/*"], "10h": ["1h/*", "9h"]}
def _get_fileids_paths(path, subset, _ext_audio) -> List[Tuple[str, str]]: def _get_fileids_paths(path, folders, _ext_audio) -> List[Tuple[str, str]]:
"""Get the file names and the corresponding file paths without `speaker_id` """Get the file names and the corresponding file paths without `speaker_id`
and `chapter_id` directories. and `chapter_id` directories.
The format of path is like: The format of path is like:
{root}/{_ARCHIVE_NAME}/1h/[0-5]/[clean, other] or {root}/{_ARCHIVE_NAME}/1h/[0-5]/[clean, other] or
{root}/{_ARCHIVE_NAME}/9h/[clean, other] {root}/{_ARCHIVE_NAME}/9h/[clean, other]
""" """
if subset == "10min":
files_paths = [ path = Path(path)
(os.path.join(os.path.dirname(p), "..", ".."), str(p.stem)) files_paths = []
for p in Path(path).glob("1h/0/*/*/*/*" + _ext_audio) for folder in folders:
] paths = [p.relative_to(path) for p in path.glob(f"{folder}/*/*/*/*{_ext_audio}")]
elif subset in ["1h", "10h"]: files_paths += [(str(p.parent.parent.parent), str(p.stem)) for p in paths] # get subset folder and file name
files_paths = [ files_paths.sort(key=lambda x: x[0] + x[1])
(os.path.join(os.path.dirname(p), "..", ".."), str(p.stem))
for p in Path(path).glob("1h/*/*/*/*/*" + _ext_audio)
]
if subset == "10h":
files_paths += [
(os.path.join(os.path.dirname(p), "..", ".."), str(p.stem))
for p in Path(path).glob("9h/*/*/*/*" + _ext_audio)
]
else:
raise ValueError(f"Unsupported subset value. Found {subset}.")
files_paths = sorted(files_paths, key=lambda x: x[0] + x[1])
return files_paths return files_paths
...@@ -63,8 +54,9 @@ class LibriLightLimited(Dataset): ...@@ -63,8 +54,9 @@ class LibriLightLimited(Dataset):
subset: str = "10min", subset: str = "10min",
download: bool = False, download: bool = False,
) -> None: ) -> None:
if subset not in ["10min", "1h", "10h"]: if subset not in _SUBSET_MAP:
raise ValueError("`subset` must be one of ['10min', '1h', '10h']") raise ValueError(f"`subset` must be one of {_SUBSET_MAP.keys()}. Found: {subset}")
folders = _SUBSET_MAP[subset]
root = os.fspath(root) root = os.fspath(root)
self._path = os.path.join(root, _ARCHIVE_NAME) self._path = os.path.join(root, _ARCHIVE_NAME)
...@@ -75,7 +67,7 @@ class LibriLightLimited(Dataset): ...@@ -75,7 +67,7 @@ class LibriLightLimited(Dataset):
if not os.path.isfile(archive): if not os.path.isfile(archive):
download_url_to_file(_URL, archive, hash_prefix=_CHECKSUM) download_url_to_file(_URL, archive, hash_prefix=_CHECKSUM)
extract_archive(archive) extract_archive(archive)
self._fileids_paths = _get_fileids_paths(self._path, subset, self._ext_audio) self._fileids_paths = _get_fileids_paths(self._path, folders, self._ext_audio)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
    """Load the n-th sample from the dataset.

    Args:
        n (int): The index of the sample to be loaded.

    Returns:
        (Tensor, int, str, int, int, int):
        ``(waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)``
    """
    subset_folder, item_id = self._fileids_paths[n]
    metadata = _get_librispeech_metadata(item_id, self._path, subset_folder, self._ext_audio, self._ext_txt)
    # metadata[0] is the audio path relative to the dataset root.
    audio_path = os.path.join(self._path, metadata[0])
    waveform, _ = torchaudio.load(audio_path)
    return (waveform,) + metadata[1:]
def __len__(self) -> int: def __len__(self) -> int:
return len(self._fileids_paths) return len(self._fileids_paths)
...@@ -10,6 +10,7 @@ from torchaudio.datasets.utils import extract_archive ...@@ -10,6 +10,7 @@ from torchaudio.datasets.utils import extract_archive
URL = "train-clean-100" URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriSpeech" FOLDER_IN_ARCHIVE = "LibriSpeech"
SAMPLE_RATE = 16000
_DATA_SUBSETS = [ _DATA_SUBSETS = [
"dev-clean", "dev-clean",
"dev-other", "dev-other",
...@@ -30,7 +31,7 @@ _CHECKSUMS = { ...@@ -30,7 +31,7 @@ _CHECKSUMS = {
} }
def download_librispeech(root, url): def _download_librispeech(root, url):
base_url = "http://www.openslr.org/resources/12/" base_url = "http://www.openslr.org/resources/12/"
ext_archive = ".tar.gz" ext_archive = ".tar.gz"
...@@ -43,20 +44,18 @@ def download_librispeech(root, url): ...@@ -43,20 +44,18 @@ def download_librispeech(root, url):
extract_archive(archive) extract_archive(archive)
def _get_librispeech_metadata(
    fileid: str, root: str, folder: str, ext_audio: str, ext_txt: str
) -> Tuple[str, int, str, int, int, int]:
    """Look up the metadata of a single LibriSpeech utterance without decoding audio.

    Args:
        fileid (str): Utterance identifier of the form ``{speaker}-{chapter}-{utterance}``.
        root (str): Directory the dataset was extracted into.
        folder (str): Subset folder relative to ``root`` (e.g. ``train-clean-100``).
        ext_audio (str): Audio file extension (e.g. ``.flac``).
        ext_txt (str): Transcript file extension (e.g. ``.trans.txt``).

    Returns:
        (str, int, str, int, int, int):
        ``(filepath, sample_rate, transcript, speaker_id, chapter_id, utterance_id)``
        where ``filepath`` is relative to ``root``.

    Raises:
        FileNotFoundError: If no transcript line matching ``fileid`` exists.
    """
    speaker_id, chapter_id, utterance_id = fileid.split("-")

    # Audio path is reported relative to root; the sample rate is the
    # corpus-wide constant, so the audio file itself is never opened here.
    fileid_audio = f"{speaker_id}-{chapter_id}-{utterance_id}"
    filepath = os.path.join(folder, speaker_id, chapter_id, f"{fileid_audio}{ext_audio}")

    # Scan the chapter's transcript file for the line matching this utterance.
    file_text = os.path.join(root, folder, speaker_id, chapter_id, f"{speaker_id}-{chapter_id}{ext_txt}")
    with open(file_text) as ft:
        for line in ft:
            fileid_text, transcript = line.strip().split(" ", 1)
            if fileid_audio == fileid_text:
                break
        else:
            raise FileNotFoundError(f"Translation not found for {fileid_audio}")

    return (
        filepath,
        SAMPLE_RATE,
        transcript,
        int(speaker_id),
        int(chapter_id),
        int(utterance_id),
    )
...@@ -102,15 +101,17 @@ class LIBRISPEECH(Dataset): ...@@ -102,15 +101,17 @@ class LIBRISPEECH(Dataset):
folder_in_archive: str = FOLDER_IN_ARCHIVE, folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False, download: bool = False,
) -> None: ) -> None:
self._url = url
if url not in _DATA_SUBSETS: if url not in _DATA_SUBSETS:
raise ValueError(f"Invalid url '{url}' given; please provide one of {_DATA_SUBSETS}.") raise ValueError(f"Invalid url '{url}' given; please provide one of {_DATA_SUBSETS}.")
root = os.fspath(root) root = os.fspath(root)
self._archive = os.path.join(root, folder_in_archive)
self._path = os.path.join(root, folder_in_archive, url) self._path = os.path.join(root, folder_in_archive, url)
if not os.path.isdir(self._path): if not os.path.isdir(self._path):
if download: if download:
download_librispeech(root, url) _download_librispeech(root, url)
else: else:
raise RuntimeError( raise RuntimeError(
f"Dataset not found at {self._path}. Please set `download=True` to download the dataset." f"Dataset not found at {self._path}. Please set `download=True` to download the dataset."
...@@ -118,6 +119,27 @@ class LIBRISPEECH(Dataset): ...@@ -118,6 +119,27 @@ class LIBRISPEECH(Dataset):
self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio)) self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio))
def get_metadata(self, n: int) -> Tuple[str, int, str, int, int, int]:
    """Get metadata for the n-th sample from the dataset.

    Returns filepath instead of waveform, but otherwise returns the same
    fields as :py:func:`__getitem__`.

    Args:
        n (int): The index of the sample to be loaded.

    Returns:
        (str, int, str, int, int, int):
        ``(filepath, sample_rate, transcript, speaker_id, chapter_id, utterance_id)``
    """
    # Fixed: return annotation previously claimed Tuple[Tensor, ...] even
    # though the first element is a filepath string, as the docstring states.
    fileid = self._walker[n]
    return _get_librispeech_metadata(fileid, self._archive, self._url, self._ext_audio, self._ext_txt)
def _load_waveform(self, path: str) -> Tensor:
    """Load the audio file at ``path`` and validate its sample rate.

    Args:
        path (str): Audio file path relative to ``self._archive``.

    Returns:
        Tensor: The decoded waveform.

    Raises:
        ValueError: If the file's sample rate differs from the expected
            corpus-wide ``SAMPLE_RATE`` (16 kHz).
    """
    # Fixed: added the missing ``-> Tensor`` return annotation for
    # consistency with the other annotated methods in this file.
    path = os.path.join(self._archive, path)
    waveform, sample_rate = torchaudio.load(path)
    if sample_rate != SAMPLE_RATE:
        raise ValueError(f"sample rate should be 16000 (16kHz), but got {sample_rate}.")
    return waveform
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
    """Load the n-th sample from the dataset.

    Args:
        n (int): The index of the sample to be loaded.

    Returns:
        (Tensor, int, str, int, int, int):
        ``(waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)``
    """
    # Fetch metadata first, then decode only the referenced audio file.
    filepath, *rest = self.get_metadata(n)
    waveform = self._load_waveform(filepath)
    return (waveform, *rest)
def __len__(self) -> int: def __len__(self) -> int:
return len(self._walker) return len(self._walker)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment