Commit 67e64225 authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Add metadata for Librimix (#2751)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2751

Reviewed By: nateanl

Differential Revision: D40267874

Pulled By: carolineechen

fbshipit-source-id: 4e45a02c650ed65c05cde82289a400a3be877927
parent c38229d4
import os
from pathlib import Path from pathlib import Path
from typing import List, Tuple, Union from typing import List, Tuple, Union
import torch import torch
import torchaudio
from torch.utils.data import Dataset from torch.utils.data import Dataset
from torchaudio.datasets.utils import _load_waveform
SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]
_TASKS_TO_MIXTURE = { _TASKS_TO_MIXTURE = {
"sep_clean": "mix_clean", "sep_clean": "mix_clean",
...@@ -55,45 +54,62 @@ class LibriMix(Dataset): ...@@ -55,45 +54,62 @@ class LibriMix(Dataset):
if mode not in ["max", "min"]: if mode not in ["max", "min"]:
raise ValueError(f'Expect ``mode`` to be one in ["min", "max"]. Found {mode}.') raise ValueError(f'Expect ``mode`` to be one in ["min", "max"]. Found {mode}.')
if sample_rate == 8000: if sample_rate == 8000:
self.root = self.root / "wav8k" / mode / subset mix_dir = self.root / "wav8k" / mode / subset
elif sample_rate == 16000: elif sample_rate == 16000:
self.root = self.root / "wav16k" / mode / subset mix_dir = self.root / "wav16k" / mode / subset
else: else:
raise ValueError(f"Unsupported sample rate. Found {sample_rate}.") raise ValueError(f"Unsupported sample rate. Found {sample_rate}.")
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.task = task self.task = task
self.mix_dir = (self.root / _TASKS_TO_MIXTURE[task]).resolve()
self.mix_dir = mix_dir / _TASKS_TO_MIXTURE[task]
if task == "enh_both": if task == "enh_both":
self.src_dirs = [(self.root / "mix_clean")] self.src_dirs = [(mix_dir / "mix_clean")]
else: else:
self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)] self.src_dirs = [(mix_dir / f"s{i+1}") for i in range(num_speakers)]
self.files = [p.name for p in self.mix_dir.glob("*.wav")] self.files = [p.name for p in self.mix_dir.glob("*.wav")]
self.files.sort() self.files.sort()
def _load_audio(self, path) -> torch.Tensor: def _load_sample(self, key) -> Tuple[int, torch.Tensor, List[torch.Tensor]]:
waveform, sample_rate = torchaudio.load(path) metadata = self.get_metadata(key)
if sample_rate != self.sample_rate: mixed = _load_waveform(self.root, metadata[1], metadata[0])
raise ValueError(
f"The dataset contains audio file of sample rate {sample_rate}, "
f"but the requested sample rate is {self.sample_rate}."
)
return waveform
def _load_sample(self, filename) -> SampleType:
mixed = self._load_audio(str(self.mix_dir / filename))
srcs = [] srcs = []
for i, dir_ in enumerate(self.src_dirs): for i, path_ in enumerate(metadata[2]):
src = self._load_audio(str(dir_ / filename)) src = _load_waveform(self.root, path_, metadata[0])
if mixed.shape != src.shape: if mixed.shape != src.shape:
raise ValueError(f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}") raise ValueError(f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}")
srcs.append(src) srcs.append(src)
return self.sample_rate, mixed, srcs return self.sample_rate, mixed, srcs
def get_metadata(self, key: int) -> Tuple[int, str, List[str]]:
"""Get metadata for the n-th sample from the dataset.
Args:
key (int): The index of the sample to be loaded
Returns:
Tuple of the following items;
int:
Sample rate
str:
Path to mixed audio
List of str:
List of paths to source audios
"""
filename = self.files[key]
mixed_path = os.path.relpath(self.mix_dir / filename, self.root)
srcs_paths = []
for dir_ in self.src_dirs:
src = os.path.relpath(dir_ / filename, self.root)
srcs_paths.append(src)
return self.sample_rate, mixed_path, srcs_paths
def __len__(self) -> int: def __len__(self) -> int:
return len(self.files) return len(self.files)
def __getitem__(self, key: int) -> SampleType: def __getitem__(self, key: int) -> Tuple[int, torch.Tensor, List[torch.Tensor]]:
"""Load the n-th sample from the dataset. """Load the n-th sample from the dataset.
Args: Args:
...@@ -106,7 +122,7 @@ class LibriMix(Dataset): ...@@ -106,7 +122,7 @@ class LibriMix(Dataset):
Sample rate Sample rate
Tensor: Tensor:
Mixture waveform Mixture waveform
list of Tensors: List of Tensors:
List of source waveforms List of source waveforms
""" """
return self._load_sample(self.files[key]) return self._load_sample(key)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment